diff --git a/megadetector.md b/megadetector.md index 59bcb934..d6e91123 100644 --- a/megadetector.md +++ b/megadetector.md @@ -203,6 +203,15 @@ python -m megadetector.utils.gpu_test If it still says "No GPUs reported by PyTorch", 95% of the time, this is fixed by updating your Nvidia driver and rebooting. If you have an Nvidia GPU, and you've installed the latest driver, and you've rebooted, and you did everything we suggested here, and you're still seeing "No GPUs reported by PyTorch", email us. +#### Using an AMD GPU + +Deep learning is pretty Nvidia-centric, to the point where "GPU" is almost synonymous with "Nvidia GPU". But... the MegaDetector Python package includes experimental support for AMD Radeon GPUs. This is not installed by default, but if you have an AMD GPU, and you run: + +```bash +pip install torch-directml +``` + +...your GPU should be detected. If this worked, you will see "Using DirectML device" in the output when you run MegaDetector. ## How do I use the results? diff --git a/megadetector/detection/pytorch_detector.py b/megadetector/detection/pytorch_detector.py index 19af50d0..95208bf2 100644 --- a/megadetector/detection/pytorch_detector.py +++ b/megadetector/detection/pytorch_detector.py @@ -873,6 +873,13 @@ def __init__(self, model_path, detector_options=None, verbose=False): self.device = 'mps' except AttributeError: pass + if self.device == 'cpu': + try: + import torch_directml + self.device = torch_directml.device() + print('Using DirectML device') + except ImportError: + pass # AddaxAI depends on this printout, don't remove it print('PTDetector using device {}'.format(str(self.device).lower())) @@ -914,14 +921,17 @@ def _load_model(model_pt_path, device, compatibility_mode='', verbose=False): # other than MPS devices. use_map_location = (device != 'mps') + # DirectML (privateuseone) doesn't support map_location; load to CPU, .to(device) handles the rest + safe_map_location = 'cpu' if 'privateuseone' in str(device) else device + if use_map_location: try: - checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False) + checkpoint = torch.load(model_pt_path, map_location=safe_map_location, weights_only=False) # For a transitional period, we want to support torch 1.1x, where the weights_only # parameter doesn't exist except Exception as e: if "'weights_only' is an invalid keyword" in str(e): - checkpoint = torch.load(model_pt_path, map_location=device) + checkpoint = torch.load(model_pt_path, map_location=safe_map_location) else: raise else: @@ -1302,6 +1312,8 @@ def _process_batch_group(self, group_items, results, detection_threshold, augmen # Run the model on the batch pred = self.model(batch_tensor, augment=augment)[0] + if 'privateuseone' in str(self.device): + pred = pred.cpu() # Configure NMS parameters if 'classic' in self.compatibility_mode: diff --git a/megadetector/utils/gpu_test.py b/megadetector/utils/gpu_test.py index b928b8fc..6b5e6d59 100644 --- a/megadetector/utils/gpu_test.py +++ b/megadetector/utils/gpu_test.py @@ -20,6 +20,36 @@ #%% Torch/TF test functions +def directml_test(): + """ + Check whether DirectML support is available. + + Returns: + bool: Whether directML support is available. + """ + + torch_directml_imported = False + + try: + + import torch_directml + print('\n*** DirectML imported, running DirectML test ***\n') + torch_directml_imported = True + + device = torch_directml.device() + print('DirectML device name: {}'.format(str(device))) + if 'privateuseone' in str(device): + print('DirectML device detected') + return True + + except Exception as e: + + if torch_directml_imported: + print('Error: {}'.format(str(e))) + + return False + + def torch_test(): """ Print diagnostic information about Torch/CUDA status, including Torch/CUDA versions @@ -123,3 +153,6 @@ def tf_test(): print('\n*** Running TF tests ***\n') tf_test() + + # This is rare, so don't include any printouts in the common case + directml_test()