Skip to content

Add Apple Silicon (MPS) GPU support#438

Open
mianaz wants to merge 1 commit into
broadinstitute:mainfrom
mianaz:feature/mps-apple-silicon-support
Open

Add Apple Silicon (MPS) GPU support#438
mianaz wants to merge 1 commit into
broadinstitute:mainfrom
mianaz:feature/mps-apple-silicon-support

Conversation

@mianaz

@mianaz mianaz commented Feb 10, 2026

Copy link
Copy Markdown

Summary

Adds support for running CellBender on Apple Silicon Macs (M1/M2/M3/M4) using PyTorch's Metal Performance Shaders (MPS) backend, resolving the long-standing request in #149.

  • Hybrid architecture: Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback for numerical stability; neural network operations run on MPS for GPU acceleration
  • Posterior-to-CPU routing: After MPS training, posterior computation runs on CPU to ensure noise estimation matches CUDA exactly
  • New CLI flags: --mps, --posterior-device, --posterior-debug

Relation to #429

PR #429 by @BenjaminDEMAILLE takes a complementary approach: cleaner device: str API generalization with PYTORCH_ENABLE_MPS_FALLBACK=1. However, through extensive testing we found that the MPS backend has several numerical issues that require explicit workarounds beyond env-var fallback:

  1. lgamma/digamma NaN on non-contiguous tensors — Training diverges after ~20 epochs without .contiguous() guards
  2. torch.nonzero returns wrong indices on MPS — Posterior produces incorrect results
  3. Posterior on MPS produces inverted denoising (87% removal vs CUDA's 5.6%) — Requires CPU routing
  4. Intermittent shape mismatch crashes — Requires batch-skip error handling
  5. Gamma sampling with PYTORCH_ENABLE_MPS_FALLBACK — Falls back to CPU without proper gradient tracking, causing high-variance REINFORCE gradients instead of pathwise reparameterized gradients

These are the same issues reported by @asabjorklund and @BradBalderson in #149. This PR addresses all of them with targeted workarounds validated against CUDA.

Validation Results

Tested with 150 epochs on a sample dataset across CPU, MPS (Apple M4 Max), and CUDA (NVIDIA RTX 5060):

Metric MPS CUDA Difference
Noise removed 11.0% 11.1% 0.1%
Cells found 11,843 11,627 1.9%
Per-cell count correlation - - 0.994
Per-gene count correlation - - 0.9999
Ambient expression cosine sim - - 0.999
Cell overlap (Jaccard) - - 82.1%

10-epoch three-way comparison: CPU 7.66%, MPS 7.66%, CUDA 7.73%.

Changes by Category

New Files (3)

File Lines Purpose
device_utils.py 232 Device detection, posterior-device resolution, model/param-store moves
pyro_mps_patch.py 315 Monkey-patches for Pyro _Subsample and PyTorch distribution MPS compatibility
mps_diagnostics.py 672 Diagnostic utilities for comparing MPS vs CUDA posterior outputs

Modified Files (16)

MPS Device Support

  • argparser.py: Add --mps, --posterior-device, --posterior-debug CLI flags
  • cli.py: Wire new args, validate MPS availability, import patches
  • run.py: Integrate device_utils, MPS-specific posterior routing
  • model.py: Add use_mps attribute, .contiguous() guards, MPS-aware device handling
  • estimation.py, monitor.py, dataprep.py, dataset.py, io.py: Device-agnostic updates

Numerical Stability Fixes (Critical for MPS)

  • NegativeBinomialPoissonConvApprox.py: .contiguous() before lgamma/digamma (prevents NaN gradients)
  • posterior.py: CPU fallback for torch.nonzero, CPU-only noise log-prob after MPS training
  • train.py: MPS batch-skip for intermittent shape mismatch errors
  • report.py: .values for pandas Series sparse matrix indexing

PyTorch 2.6+ Compatibility

  • checkpoint.py: weights_only=False + safe_globals for constraint objects

Tests

  • tests/test_posterior.py: Posterior device override tests

Key Technical Details

Why Hybrid (not Full-MPS)?

Full-MPS with PYTORCH_ENABLE_MPS_FALLBACK=1 (as in #429) has several problems:

  1. Gamma sampling falls back to CPU without gradient tracking, so Pyro uses high-variance REINFORCE estimator instead of pathwise reparameterized gradients (~100x gradient variance)
  2. Non-contiguous tensors from .expand() cause lgamma to return NaN gradients — training diverges
  3. torch.nonzero returns wrong indices — posterior computes incorrect noise counts
  4. Even with correct training, MPS posterior converges to a different local minimum (87% noise removal vs 5.6% on CUDA)

Our hybrid approach: CPU sampling with explicit gradient tracking + MPS neural network ops. ~1.6x slower than CUDA but produces correct results.

Test plan

  • All existing tests pass (pytest cellbender/remove_background/tests/ -v)
  • 10-epoch CPU/MPS/CUDA three-way comparison: all within 0.07%
  • 150-epoch MPS/CUDA comparison: within 0.1% noise removal
  • Per-cell count correlation > 0.99 between MPS and CUDA
  • Checkpoint save/load round-trip validated
  • Test on additional datasets (community testing welcome)

Acknowledgments

Based on CellBender v0.3.2. Builds on the original MPS exploration by @sjfleming (sf_pytorch_mps_backend branch) and the device generalization work by @BenjaminDEMAILLE (#429).

Fleming, S.J., Chaffin, M.D., Awatramani, A. et al. Nature Methods 20, 1323–1335 (2023). https://doi.org/10.1038/s41592-023-01943-7

Closes #149

🤖 Generated with Claude Code

Add support for running CellBender on Apple Silicon Macs using
PyTorch's MPS backend. This enables GPU-accelerated ambient RNA
removal on M1/M2/M3/M4 Macs without requiring NVIDIA CUDA hardware.

Architecture: Hybrid CPU Sampling + MPS Neural Network
- Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback via
  PyTorch's verified implementations to avoid MPS numerical issues
- Neural network operations (encoder, decoder) run on MPS for ~1.5-2x
  speedup vs pure CPU
- Posterior computation routes to CPU after MPS training to ensure
  noise estimation matches CUDA exactly

New CLI flags:
- --mps: Enable MPS (Apple Silicon) GPU acceleration
- --posterior-device {auto,model,cpu,cuda,mps}: Control posterior device
- --posterior-debug: Log detailed posterior diagnostics

Validated results (150 epochs, sample dataset):
- MPS noise removal: 11.0% vs CUDA: 11.1% (within 0.1%)
- MPS cells found: 11,843 vs CUDA: 11,627 (82% Jaccard overlap)
- Per-cell count correlation: 0.994
- Per-gene count correlation: 0.9999
- Ambient expression cosine similarity: 0.999

Key fixes:
- Patch Pyro's _Subsample to accept MPS devices (use_cuda=False)
- CPU fallback for Gamma/Beta/Dirichlet sampling (exact gradients)
- .contiguous() before lgamma/digamma on MPS (NaN gradient fix)
- Posterior-to-CPU routing when training on MPS (noise CDF fix)
- torch.nonzero CPU fallback on MPS (wrong indices bug)
- PyTorch 2.6+ checkpoint compatibility (weights_only handling)
- MPS batch skip for intermittent shape mismatch errors
- pandas Series .values for sparse matrix indexing in reports

New files:
- device_utils.py: Device detection, posterior routing, model moves
- pyro_mps_patch.py: Monkey-patches for Pyro/PyTorch MPS compat
- mps_diagnostics.py: Diagnostic utilities for MPS debugging

Based on CellBender v0.3.2 by Stephen Fleming et al.
Original paper: Fleming et al., Nature Methods, 2023
https://doi.org/10.1038/s41592-023-01943-7

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@BenjaminDEMAILLE

Copy link
Copy Markdown

look my PR on pytorch/pytorch#173319

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

"Cuda" counterpart for Apple M1 computers

2 participants