From 2c2b4878fad0585cbf7f74c4735dcd3b83922561 Mon Sep 17 00:00:00 2001 From: "Sara R. Birkby" Date: Mon, 4 May 2026 11:47:46 -0500 Subject: [PATCH 1/5] Add DirectML support for AMD GPU inference on Windows/WSL Three minimal changes to pytorch_detector.py: - DirectML device detection after MPS fallback - safe_map_location for privateuseone device compatibility - pred.cpu() after model call for downstream CPU operations --- megadetector/detection/pytorch_detector.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/megadetector/detection/pytorch_detector.py b/megadetector/detection/pytorch_detector.py index 19af50d0..95208bf2 100644 --- a/megadetector/detection/pytorch_detector.py +++ b/megadetector/detection/pytorch_detector.py @@ -873,6 +873,13 @@ def __init__(self, model_path, detector_options=None, verbose=False): self.device = 'mps' except AttributeError: pass + if self.device == 'cpu': + try: + import torch_directml + self.device = torch_directml.device() + print('Using DirectML device') + except ImportError: + pass # AddaxAI depends on this printout, don't remove it print('PTDetector using device {}'.format(str(self.device).lower())) @@ -914,14 +921,17 @@ def _load_model(model_pt_path, device, compatibility_mode='', verbose=False): # other than MPS devices. use_map_location = (device != 'mps') + # DirectML (privateuseone) doesn't support map_location; load to CPU, .to(device) handles the rest + safe_map_location = 'cpu' if 'privateuseone' in str(device) else device + if use_map_location: try: - checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False) + checkpoint = torch.load(model_pt_path, map_location=safe_map_location, weights_only=False) # For a transitional period, we want to support torch 1.1x, where the weights_only # parameter doesn't exist except Exception as e: if "'weights_only' is an invalid keyword" in str(e): - checkpoint = torch.load(model_pt_path, map_location=device) + checkpoint = torch.load(model_pt_path, map_location=safe_map_location) else: raise else: @@ -1302,6 +1312,8 @@ def _process_batch_group(self, group_items, results, detection_threshold, augmen # Run the model on the batch pred = self.model(batch_tensor, augment=augment)[0] + if 'privateuseone' in str(self.device): + pred = pred.cpu() # Configure NMS parameters if 'classic' in self.compatibility_mode: From d9ad214fab0ba8706ea6bb398f646344091f49ce Mon Sep 17 00:00:00 2001 From: Dan Morris Date: Thu, 7 May 2026 21:18:19 +0000 Subject: [PATCH 2/5] add directml test to gpu_test --- megadetector/utils/gpu_test.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/megadetector/utils/gpu_test.py b/megadetector/utils/gpu_test.py index b928b8fc..4bf79117 100644 --- a/megadetector/utils/gpu_test.py +++ b/megadetector/utils/gpu_test.py @@ -20,6 +20,28 @@ #%% Torch/TF test functions +def directml_test(): + """ + Check whether DirectML support is available. + + Returns: + bool: Whether directML support is available. + """ + + try: + import torch_directml + print('\n*** DirectML imported, running DirectML test ***\n') + + device = torch_directml.device() + print('DirectML device name: {}'.format(str(device))) + if 'privateuseone' in str(device): + print('DirectML device detected') + return True + except Exception: + pass + return False + + def torch_test(): """ Print diagnostic information about Torch/CUDA status, including Torch/CUDA versions @@ -123,3 +145,6 @@ def tf_test(): print('\n*** Running TF tests ***\n') tf_test() + + # This is rare, so don't include any printouts in the common case + directml_test() From 00b028c0003e8927570ab6c8db9d7c6a639ce873 Mon Sep 17 00:00:00 2001 From: Dan Morris Date: Thu, 7 May 2026 21:29:18 +0000 Subject: [PATCH 3/5] README update --- megadetector.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/megadetector.md b/megadetector.md index 59bcb934..d6e91123 100644 --- a/megadetector.md +++ b/megadetector.md @@ -203,6 +203,15 @@ python -m megadetector.utils.gpu_test If it still says "No GPUs reported by PyTorch", 95% of the time, this is fixed by updating your Nvidia driver and rebooting. If you have an Nvidia GPU, and you've installed the latest driver, and you've rebooted, and you did everything we suggested here, and you're still seeing "No GPUs reported by PyTorch", email us. +#### Using an AMD GPU + +Deep learning is pretty Nvidia-centric, to the point where "GPU" is almost synonymous with "Nvidia GPU". But... the MegaDetector Python package includes experimental support for AMD Radeon GPUs. This is not installed by default, but if you have an AMD GPU, and you run: + +```bash +pip install torch-directml +``` + +...your GPU should be detected. If this worked, you will see "Using DirectML device" in the output when you run MegaDetector. ## How do I use the results? From e7c7eda384dfb0f5a5ffcae3327eaee987bf295a Mon Sep 17 00:00:00 2001 From: Dan Morris Date: Thu, 7 May 2026 21:34:15 +0000 Subject: [PATCH 4/5] gpu_test update --- megadetector/utils/gpu_test.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/megadetector/utils/gpu_test.py b/megadetector/utils/gpu_test.py index 4bf79117..6a9f4af1 100644 --- a/megadetector/utils/gpu_test.py +++ b/megadetector/utils/gpu_test.py @@ -28,17 +28,25 @@ def directml_test(): bool: Whether directML support is available. """ + torch_directml_imported = False + try: + import torch_directml print('\n*** DirectML imported, running DirectML test ***\n') + torch_directml_imported = True device = torch_directml.device() print('DirectML device name: {}'.format(str(device))) if 'privateuseone' in str(device): print('DirectML device detected') return True - except Exception: - pass + + except Exception as e: + + if torch_directml_imported: + print('Error: {}'.format(str(e))) + return False From 2ee04a873efb8eb2f063d09b97fa2fb7022d23e7 Mon Sep 17 00:00:00 2001 From: Dan Morris Date: Thu, 7 May 2026 15:08:40 -0700 Subject: [PATCH 5/5] whitespace fix --- megadetector/utils/gpu_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/megadetector/utils/gpu_test.py b/megadetector/utils/gpu_test.py index 6a9f4af1..6b5e6d59 100644 --- a/megadetector/utils/gpu_test.py +++ b/megadetector/utils/gpu_test.py @@ -41,12 +41,12 @@ def directml_test(): if 'privateuseone' in str(device): print('DirectML device detected') return True - + except Exception as e: - + if torch_directml_imported: print('Error: {}'.format(str(e))) - + return False