Add Apple Silicon (MPS) GPU support#437
Closed
mianaz wants to merge 1 commit into
Closed
Conversation
Add support for running CellBender on Apple Silicon Macs using
PyTorch's MPS backend. This enables GPU-accelerated ambient RNA
removal on M1/M2/M3/M4 Macs without requiring NVIDIA CUDA hardware.
Architecture: Hybrid CPU Sampling + MPS Neural Network
- Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback via
PyTorch's verified implementations to avoid MPS numerical issues
- Neural network operations (encoder, decoder) run on MPS for ~1.5-2x
speedup vs pure CPU
- Posterior computation routes to CPU after MPS training to ensure
noise estimation matches CUDA exactly
New CLI flags:
- --mps: Enable MPS (Apple Silicon) GPU acceleration
- --posterior-device {auto,model,cpu,cuda,mps}: Control posterior device
- --posterior-debug: Log detailed posterior diagnostics
Validated results (150 epochs, sample dataset):
- MPS noise removal: 11.0% vs CUDA: 11.1% (within 0.1%)
- MPS cells found: 11,843 vs CUDA: 11,627 (82% Jaccard overlap)
- Per-cell count correlation: 0.994
- Per-gene count correlation: 0.9999
- Ambient expression cosine similarity: 0.999
Key fixes:
- Patch Pyro's _Subsample to accept MPS devices (use_cuda=False)
- CPU fallback for Gamma/Beta/Dirichlet sampling (exact gradients)
- .contiguous() before lgamma/digamma on MPS (NaN gradient fix)
- Posterior-to-CPU routing when training on MPS (noise CDF fix)
- torch.nonzero CPU fallback on MPS (wrong indices bug)
- PyTorch 2.6+ checkpoint compatibility (weights_only handling)
- MPS batch skip for intermittent shape mismatch errors
- pandas Series .values for sparse matrix indexing in reports
New files:
- device_utils.py: Device detection, posterior routing, model moves
- pyro_mps_patch.py: Monkey-patches for Pyro/PyTorch MPS compat
- mps_diagnostics.py: Diagnostic utilities for MPS debugging
Based on CellBender v0.3.2 by Stephen Fleming et al.
Original paper: Fleming et al., Nature Methods, 2023
https://doi.org/10.1038/s41592-023-01943-7
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author
|
Closing to rework: will incorporate cleaner device abstraction pattern from #429 and resubmit with squashed history. The numerical fixes and validation results remain the same. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds support for running CellBender on Apple Silicon Macs (M1/M2/M3/M4) using PyTorch's Metal Performance Shaders (MPS) backend. This enables GPU-accelerated ambient RNA removal without requiring NVIDIA CUDA hardware.
--mps,--posterior-device,--posterior-debugValidation Results
Tested with 150 epochs on a sample dataset across CPU, MPS (Apple M4 Max), and CUDA (NVIDIA RTX 5060):
10-epoch three-way comparison: CPU 7.66%, MPS 7.66%, CUDA 7.73%.
Changes by Category
New Files (3)
device_utils.pypyro_mps_patch.pymps_diagnostics.pyModified Files (16)
MPS Device Support
argparser.py: Add--mps,--posterior-device,--posterior-debugCLI flagscli.py: Wire new args through, setuse_mpsdefault, handle MPS device selectionrun.py: Integrate device_utils for device selection, add MPS-specific posterior routing, extract_construct_model_and_loaders()helpermodel.py: Adduse_mpsattribute,_make_all_parameters_contiguous()method, MPS-aware device handling in forward/guideestimation.py: Use device_utils for GPU detection instead of CUDA-only checksmonitor.py: MPS-aware memory reporting, device-agnostic monitoringtrain.py: MPS batch-skip for intermittent shape mismatch errors, device-aware training loopNumerical Stability Fixes
NegativeBinomialPoissonConvApprox.py:.contiguous()beforelgamma/digammato prevent NaN gradients on MPSposterior.py: CPU fallback fortorch.nonzeroon MPS, CPU-only noise log-prob when training used MPS, chi_ambient floor, posterior-device supportreport.py:.valuesfor pandas Series → numpy conversion (sparse matrix indexing fix), HTML report stylingPyTorch Compatibility
checkpoint.py: PyTorch 2.6+weights_only=False+safe_globals, robust format detection for save/loaddata/dataprep.py: Replace CUDA-specific tensor creation with device-agnostic callsdata/dataset.py: Device-agnostic tensor operationsdata/io.py: MPS device handling for data loadingTests
tests/test_posterior.py: Add posterior device override testsKey Technical Details
Why Hybrid (not Full-MPS)?
Full-MPS had scale-dependent failures: works on small datasets, fails on large ones (25,000+ droplets). Root causes:
lgamma/digammaproduce NaN on non-contiguous tensors from.expand()torch.nonzeroreturns wrong indices on MPSMPS Traps Documented
.expand()→ always call.contiguous()lgamma/digammaon expanded tensors → NaN gradients (CRITICAL)float64unsupported on MPS → usefloat32everywheretorch.nonzeroreturns wrong indices → move to CPU firstunsqueeze→ ensure both tensors are contiguousTest plan
pytest cellbender/remove_background/tests/ -v)Acknowledgments
Based on CellBender v0.3.2. Original work by Stephen Fleming et al.
🤖 Generated with Claude Code