Add Apple Silicon (MPS) GPU support#438
Open
mianaz wants to merge 1 commit into
Open
Conversation
Add support for running CellBender on Apple Silicon Macs using
PyTorch's MPS backend. This enables GPU-accelerated ambient RNA
removal on M1/M2/M3/M4 Macs without requiring NVIDIA CUDA hardware.
Architecture: Hybrid CPU Sampling + MPS Neural Network
- Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback via
PyTorch's verified implementations to avoid MPS numerical issues
- Neural network operations (encoder, decoder) run on MPS for ~1.5-2x
speedup vs pure CPU
- Posterior computation routes to CPU after MPS training to ensure
noise estimation matches CUDA exactly
New CLI flags:
- --mps: Enable MPS (Apple Silicon) GPU acceleration
- --posterior-device {auto,model,cpu,cuda,mps}: Control posterior device
- --posterior-debug: Log detailed posterior diagnostics
Validated results (150 epochs, sample dataset):
- MPS noise removal: 11.0% vs CUDA: 11.1% (within 0.1%)
- MPS cells found: 11,843 vs CUDA: 11,627 (82% Jaccard overlap)
- Per-cell count correlation: 0.994
- Per-gene count correlation: 0.9999
- Ambient expression cosine similarity: 0.999
Key fixes:
- Patch Pyro's _Subsample to accept MPS devices (use_cuda=False)
- CPU fallback for Gamma/Beta/Dirichlet sampling (exact gradients)
- .contiguous() before lgamma/digamma on MPS (NaN gradient fix)
- Posterior-to-CPU routing when training on MPS (noise CDF fix)
- torch.nonzero CPU fallback on MPS (wrong indices bug)
- PyTorch 2.6+ checkpoint compatibility (weights_only handling)
- MPS batch skip for intermittent shape mismatch errors
- pandas Series .values for sparse matrix indexing in reports
New files:
- device_utils.py: Device detection, posterior routing, model moves
- pyro_mps_patch.py: Monkey-patches for Pyro/PyTorch MPS compat
- mps_diagnostics.py: Diagnostic utilities for MPS debugging
Based on CellBender v0.3.2 by Stephen Fleming et al.
Original paper: Fleming et al., Nature Methods, 2023
https://doi.org/10.1038/s41592-023-01943-7
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
look my PR on pytorch/pytorch#173319 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Adds support for running CellBender on Apple Silicon Macs (M1/M2/M3/M4) using PyTorch's Metal Performance Shaders (MPS) backend, resolving the long-standing request in #149.
--mps,--posterior-device,--posterior-debugRelation to #429
PR #429 by @BenjaminDEMAILLE takes a complementary approach: cleaner
device: strAPI generalization withPYTORCH_ENABLE_MPS_FALLBACK=1. However, through extensive testing we found that the MPS backend has several numerical issues that require explicit workarounds beyond env-var fallback:lgamma/digammaNaN on non-contiguous tensors — Training diverges after ~20 epochs without.contiguous()guardstorch.nonzeroreturns wrong indices on MPS — Posterior produces incorrect resultsPYTORCH_ENABLE_MPS_FALLBACK— Falls back to CPU without proper gradient tracking, causing high-variance REINFORCE gradients instead of pathwise reparameterized gradientsThese are the same issues reported by @asabjorklund and @BradBalderson in #149. This PR addresses all of them with targeted workarounds validated against CUDA.
Validation Results
Tested with 150 epochs on a sample dataset across CPU, MPS (Apple M4 Max), and CUDA (NVIDIA RTX 5060):
10-epoch three-way comparison: CPU 7.66%, MPS 7.66%, CUDA 7.73%.
Changes by Category
New Files (3)
device_utils.pypyro_mps_patch.pymps_diagnostics.pyModified Files (16)
MPS Device Support
argparser.py: Add--mps,--posterior-device,--posterior-debugCLI flagscli.py: Wire new args, validate MPS availability, import patchesrun.py: Integrate device_utils, MPS-specific posterior routingmodel.py: Adduse_mpsattribute,.contiguous()guards, MPS-aware device handlingestimation.py,monitor.py,dataprep.py,dataset.py,io.py: Device-agnostic updatesNumerical Stability Fixes (Critical for MPS)
NegativeBinomialPoissonConvApprox.py:.contiguous()beforelgamma/digamma(prevents NaN gradients)posterior.py: CPU fallback fortorch.nonzero, CPU-only noise log-prob after MPS trainingtrain.py: MPS batch-skip for intermittent shape mismatch errorsreport.py:.valuesfor pandas Series sparse matrix indexingPyTorch 2.6+ Compatibility
checkpoint.py:weights_only=False+safe_globalsfor constraint objectsTests
tests/test_posterior.py: Posterior device override testsKey Technical Details
Why Hybrid (not Full-MPS)?
Full-MPS with
PYTORCH_ENABLE_MPS_FALLBACK=1(as in #429) has several problems:.expand()causelgammato return NaN gradients — training divergestorch.nonzeroreturns wrong indices — posterior computes incorrect noise countsOur hybrid approach: CPU sampling with explicit gradient tracking + MPS neural network ops. ~1.6x slower than CUDA but produces correct results.
Test plan
pytest cellbender/remove_background/tests/ -v)Acknowledgments
Based on CellBender v0.3.2. Builds on the original MPS exploration by @sjfleming (
sf_pytorch_mps_backendbranch) and the device generalization work by @BenjaminDEMAILLE (#429).Fleming, S.J., Chaffin, M.D., Awatramani, A. et al. Nature Methods 20, 1323–1335 (2023). https://doi.org/10.1038/s41592-023-01943-7
Closes #149
🤖 Generated with Claude Code