Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions megadetector.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,15 @@ python -m megadetector.utils.gpu_test

If it <i>still</i> says "No GPUs reported by PyTorch", 95% of the time, this is fixed by <a href="https://www.nvidia.com/en-us/geforce/drivers/">updating your Nvidia driver</a> and rebooting. If you have an Nvidia GPU, and you've installed the latest driver, and you've rebooted, and you did everything we suggested here, and you're still seeing "No GPUs reported by PyTorch", <a href="mailto:cameratraps@lila.science">email us</a>.

#### Using an AMD GPU

Deep learning is pretty Nvidia-centric, to the point where "GPU" is almost synonymous with "Nvidia GPU". But... the MegaDetector Python package includes experimental support for AMD Radeon GPUs via DirectML. The package that enables this (torch-directml) is not installed by default, but if you have an AMD GPU, and you run:

```bash
pip install torch-directml
```

...your GPU should be detected. If this works, you will see "Using DirectML device" in the output when you run MegaDetector.

## How do I use the results?

Expand Down
16 changes: 14 additions & 2 deletions megadetector/detection/pytorch_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,13 @@ def __init__(self, model_path, detector_options=None, verbose=False):
self.device = 'mps'
except AttributeError:
pass
if self.device == 'cpu':
try:
import torch_directml
self.device = torch_directml.device()
print('Using DirectML device')
except ImportError:
pass

# AddaxAI depends on this printout, don't remove it
print('PTDetector using device {}'.format(str(self.device).lower()))
Expand Down Expand Up @@ -914,14 +921,17 @@ def _load_model(model_pt_path, device, compatibility_mode='', verbose=False):
# other than MPS devices.
use_map_location = (device != 'mps')

# DirectML (privateuseone) doesn't support map_location; load to CPU, .to(device) handles the rest
safe_map_location = 'cpu' if 'privateuseone' in str(device) else device

if use_map_location:
try:
checkpoint = torch.load(model_pt_path, map_location=device, weights_only=False)
checkpoint = torch.load(model_pt_path, map_location=safe_map_location, weights_only=False)
# For a transitional period, we want to support torch 1.1x, where the weights_only
# parameter doesn't exist
except Exception as e:
if "'weights_only' is an invalid keyword" in str(e):
checkpoint = torch.load(model_pt_path, map_location=device)
checkpoint = torch.load(model_pt_path, map_location=safe_map_location)
else:
raise
else:
Expand Down Expand Up @@ -1302,6 +1312,8 @@ def _process_batch_group(self, group_items, results, detection_threshold, augmen

# Run the model on the batch
pred = self.model(batch_tensor, augment=augment)[0]
if 'privateuseone' in str(self.device):
pred = pred.cpu()

# Configure NMS parameters
if 'classic' in self.compatibility_mode:
Expand Down
33 changes: 33 additions & 0 deletions megadetector/utils/gpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,36 @@

#%% Torch/TF test functions

def directml_test():
    """
    Check whether DirectML support is available.

    Nothing is printed if the torch_directml package cannot be imported, so
    the common case (no DirectML installed) stays silent. Once the import has
    succeeded, a banner and the device name are printed, and any exception
    raised while querying the device is reported as an "Error: ..." line.

    Returns:
        bool: True if torch_directml imports and the string form of its device
        contains 'privateuseone'; False in every other case, including when
        an exception is raised.
    """

    # Tracks whether we got past the import, so the exception handler can
    # tell "package not installed" (stay quiet) from "package is installed but
    # something went wrong" (report the error).
    torch_directml_imported = False

    try:

        import torch_directml
        print('\n*** DirectML imported, running DirectML test ***\n')
        torch_directml_imported = True

        device = torch_directml.device()
        device_string = str(device)
        print('DirectML device name: {}'.format(device_string))

        # DirectML devices are identified by 'privateuseone' in their string form
        if 'privateuseone' in device_string:
            print('DirectML device detected')
            return True

    except Exception as e:

        # Deliberately broad: this is a best-effort diagnostic that should
        # never crash the caller. Import failures are swallowed silently.
        if torch_directml_imported:
            print('Error: {}'.format(str(e)))

    return False


def torch_test():
"""
Print diagnostic information about Torch/CUDA status, including Torch/CUDA versions
Expand Down Expand Up @@ -123,3 +153,6 @@ def tf_test():

print('\n*** Running TF tests ***\n')
tf_test()

# DirectML is rarely installed, so this test prints nothing unless torch_directml imports successfully
directml_test()
Loading