Refactor hardware access with dependency injection #137
Changes from all commits: 018c331, d917615, 56ba029, 25ff1d4, 326c789
```diff
@@ -2,18 +2,18 @@
 # SPDX-License-Identifier: Apache-2.0
 """Check for NVLink status."""

-import pynvml
+from rapids_cli.hardware import HardwareInfoError
+from rapids_cli.providers import get_gpu_info


 def check_nvlink_status(verbose=True, **kwargs):
     """Check NVLink status across all GPUs."""
+    gpu_info = get_gpu_info()
     try:
-        pynvml.nvmlInit()
-    except pynvml.NVMLError as e:
+        device_count = gpu_info.device_count
+    except HardwareInfoError as e:
         raise ValueError("GPU not found. Please ensure GPUs are installed.") from e

-    device_count = pynvml.nvmlDeviceGetCount()
-
     # NVLink requires at least 2 GPUs to be meaningful. A single GPU has nothing
     # to link to, so there is nothing to check.
     if device_count < 2:
@@ -23,29 +23,20 @@ def check_nvlink_status(verbose=True, **kwargs):
     # model). Mixed configurations — e.g. some NVLink-capable GPUs alongside some
     # that are not — are not handled and may produce misleading results.

-    failed_links: list[tuple[int, int]] = []
-
-    for gpu_idx in range(device_count):
-        handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_idx)
-        # NVML provides no API to query the number of NVLink slots on a device
-        # (e.g. V100=6, A100=12, H100=18). The only way to discover the real count
-        # is to iterate up to NVML_NVLINK_MAX_LINKS and stop when the driver signals
-        # that link_id is out of range via NVMLError_InvalidArgument.
-        for link_id in range(pynvml.NVML_NVLINK_MAX_LINKS):
-            try:
-                # nvmlDeviceGetNvLinkState(device, link) returns NVML_FEATURE_ENABLED
-                # if the link is active, or NVML_FEATURE_DISABLED if it is not.
-                state = pynvml.nvmlDeviceGetNvLinkState(handle, link_id)
-                if state == pynvml.NVML_FEATURE_DISABLED:
-                    failed_links.append((gpu_idx, link_id))
-            except pynvml.NVMLError_NotSupported:
-                # The driver reports NVLink is not supported on this system.
-                # There is nothing to check — skip like the single-GPU case above.
-                return False
-            except pynvml.NVMLError_InvalidArgument:
-                # link_id exceeds the number of NVLink slots on this device.
-                # Stop iterating links for this GPU.
-                break
+    devices = gpu_info.devices
+
+    # An empty nvlink_states means the driver reported NVLink as unsupported (or
+    # no links were enumerated) for that device. Treat a system where no device
+    # advertises links the same as the single-GPU case — nothing to check.
+    if all(not dev.nvlink_states for dev in devices):
+        return False
+
+    failed_links: list[tuple[int, int]] = [
+        (dev.index, link_id)
+        for dev in devices
+        for link_id, active in enumerate(dev.nvlink_states)
+        if not active
+    ]

     if failed_links:
         details = ", ".join(f"GPU {gpu} link {link}" for gpu, link in failed_links)
```
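For orientation, here is a minimal sketch of the provider data model the refactored check consumes. Only the attribute names it reads (`device_count`, `devices`, `index`, `nvlink_states`) come from the diff; the dataclass layout is an assumption, and the real definitions live elsewhere in `rapids_cli` (the diff only shows `HardwareInfoError` and `get_gpu_info`).

```python
# Sketch of the provider data model implied by the diff. The class names and
# defaults are assumptions; only the attribute names used by the check
# (device_count, devices, index, nvlink_states) come from the code above.
from dataclasses import dataclass, field


@dataclass
class GpuDevice:
    index: int
    # One bool per NVLink slot: True if the link is active. An empty list
    # means the driver reported NVLink as unsupported (or enumerated no links).
    nvlink_states: list[bool] = field(default_factory=list)


@dataclass
class GpuInfo:
    devices: list[GpuDevice] = field(default_factory=list)

    @property
    def device_count(self) -> int:
        # In the real provider this access can raise HardwareInfoError when
        # no GPUs are visible; this sketch simply counts the devices.
        return len(self.devices)
```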
**Comment on lines +34 to +39**

**Contributor:** I haven't had time to try this, but I'm not sure it covers the same cases we had before. The tests seem to have been changed too, so it's hard to tell.

**Contributor (Author):** The refactored check covers the same cases, just from provider data: a device where the driver reports NVLink as unsupported now surfaces as an empty `nvlink_states`, so the `all(not dev.nvlink_states for dev in devices)` guard reproduces the old `return False` path, and the comprehension collects the same `(gpu, link)` pairs the old nested loop appended. The test changes replace the pynvml mocks with fakes installed through the provider registry.
**Contributor:** What happened here with passing `toolkit_info: CudaToolkitInfo`?

**Contributor (Author):** The `toolkit_info` parameter was replaced by the provider registry: `cuda_toolkit_check` now reads from `get_toolkit_info()` (via `rapids_cli.providers`) instead of receiving it as a kwarg. This is the core change of the DI refactor: checks no longer need provider parameters threaded through their signatures. Tests install fakes into the registry via `monkeypatch.setattr` fixtures in conftest.py. Also reverted the local variable name back to `toolkit_info` in 56ba029 to minimize the diff.
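A sketch of what such a conftest.py fixture could look like. The check-module path (`rapids_cli.checks.nvlink`) and the `SimpleNamespace` fake are assumptions; the thread only confirms that fakes are installed with `monkeypatch.setattr`.

```python
# tests/conftest.py (sketch; the module path rapids_cli.checks.nvlink is assumed)
from types import SimpleNamespace

import pytest


@pytest.fixture
def fake_gpu_info(monkeypatch):
    # A fake two-GPU system where GPU 0's second NVLink is down.
    fake = SimpleNamespace(
        device_count=2,
        devices=[
            SimpleNamespace(index=0, nvlink_states=[True, False]),
            SimpleNamespace(index=1, nvlink_states=[True, True]),
        ],
    )
    # The check does `from rapids_cli.providers import get_gpu_info`, so the
    # patch targets the name where the check looks it up (its own module),
    # not rapids_cli.providers itself.
    monkeypatch.setattr("rapids_cli.checks.nvlink.get_gpu_info", lambda: fake)
    return fake
```

With this fixture in place, `check_nvlink_status()` would report GPU 0 link 1 as failed without touching real hardware.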