Add MPS (Metal Performance Shaders) support for Apple Silicon GPUs#429
Open
BenjaminDEMAILLE wants to merge 3 commits into
Open
Add MPS (Metal Performance Shaders) support for Apple Silicon GPUs#429BenjaminDEMAILLE wants to merge 3 commits into
BenjaminDEMAILLE wants to merge 3 commits into
Conversation
- Add --mps command-line argument to enable MPS backend - Generalize device selection to support cuda/mps/cpu - Update all modules to use device parameter instead of use_cuda flag - Add MPS availability checks and user warnings in CLI This enables GPU acceleration on Apple Silicon (M1/M2/M3/M4) Macs using PyTorch's MPS backend, providing significant speed-ups for inference on macOS devices. The implementation is backward compatible with existing --cuda and CPU workflows, with CUDA taking precedence over MPS when both are available. Based on the original MPS implementation from commit 8a70ea3 by Stephen Fleming, adapted for CellBender v0.3.2 code structure. Addresses issue broadinstitute#149
PyTorch MPS backend doesn't support all operations (e.g., aten::_standard_gamma used by Gamma distribution). This commit automatically enables PYTORCH_ENABLE_MPS_FALLBACK=1 when --mps is used, allowing unsupported operations to fall back to CPU. This provides a better user experience and enables CellBender to run on Apple Silicon without manual environment variable configuration. While some operations will use CPU fallback (slower), the overall performance is still significantly better than CPU-only mode. Related to issue broadinstitute#149
The PYTORCH_ENABLE_MPS_FALLBACK environment variable must be set BEFORE torch is imported for it to take effect. Moved the env var setting to the very beginning of base_cli.main(), before get_populated_argparser() and generate_cli_dictionary() which trigger module imports. This ensures the fallback is enabled for unsupported MPS operations like aten::_standard_gamma used by Gamma distribution in Pyro. Tested and confirmed working - Gamma sampling on MPS now falls back to CPU instead of throwing NotImplementedError.
Author
|
Hi ! |
This was referenced Feb 10, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add MPS (Metal Performance Shaders) support for Apple Silicon GPUs
This PR brings native PyTorch MPS support to CellBender 0.3.2 so users on Apple Silicon (M1/M2/M3/M4) Macs can run GPU‑accelerated inference on macOS. It adapts the original work from commit
8a70ea3on the legacysf_pytorch_mps_backendbranch (v0.2.0) to the 0.3.2 codebase.8a70ea306efbfd63f2a219b93b1b2c2749f641dfadd-mps-supportSummary
--mpsto enable the PyTorch MPS backend.use_cuda, a stringdeviceis passed end‑to‑end and can be one ofcuda,mps, orcpu.Why this change?
Changes (by file)
cellbender/remove_background/argparser.py--mpsflag with help text and link to PyTorch MPS docs.cellbender/remove_background/cli.pyargs.devicetocuda(if--cudaand available), elsemps(if--mpsand available), elsecpu.cellbender/remove_background/model.pydevice: strinstead ofuse_cuda: bool..to(device)(model and submodules) instead of.cuda().use_cudastate; storesself.deviceonly.pyro.plate(..., use_cuda=..., device=...)withpyro.plate(..., device=...).cellbender/remove_background/data/dataprep.pyDataLoaderacceptsdevice: strand pushes tensors to that device.prep_sparse_data_for_training(...)acceptsdeviceand propagates it to loaders.cellbender/remove_background/data/dataset.pyget_dataloader(...)now takesdevice: str(instead ofuse_cuda: bool) and forwards it toDataLoader.cellbender/remove_background/run.pyargs.deviceconsistently.args.device == 'cpu'.force_deviceconsistent with selected backend.deviceinto posterior computations and estimators.cellbender/remove_background/posterior.pydevicefromvi_model.device(or sensible fallback).deviceexplicitly.cellbender/remove_background/train.pymodel.device == 'cuda'(instead ofmodel.use_cuda).torch.cuda.empty_cache()if CUDA.torch.mps.empty_cache()if MPS and available (wrapped in try/except).Device selection behavior
--cudais provided andtorch.cuda.is_available(), usecuda.--mpsis provided andtorch.backends.mps.is_available()andtorch.backends.mps.is_built(), usemps.cpu.CUDA takes precedence over MPS when both are requested/available.
How to use
To verify the flag is visible:
cellbender remove-background --help | grep -A 4 -- --mpsTesting performed
--mpsappears in the CLI help.torch.backends.mps.is_available()andis_built()on Apple Silicon test machine.--cudaand CPU remain intact.Note: Full end-to-end test suite (including GPU tests) should be run in CI or by maintainers; this PR aims to be minimally invasive while restoring the MPS feature.
Limitations and follow-ups
cellbender/monitor.py) prints GPU utilization vianvidia-smi(CUDA only). There’s no analogous standard CLI for MPS; for now, logs omit MPS GPU utilization. A future improvement could add optional macOS/MPS metrics if a stable API becomes available.Backward compatibility
--mpsis additive.force_devicewhere appropriate.Related work
8a70ea3(Stephen Fleming) onsf_pytorch_mps_backend(0.2.0).Reviewer notes
use_cudaassumptions in the active code paths.Checklist
--mpsflag and help textThanks for reviewing! This should unlock fast, native GPU acceleration for a large portion of the community using Apple Silicon machines, while preserving the familiar CUDA and CPU paths.