Skip to content

Add Apple Silicon (MPS) GPU support#437

Closed
mianaz wants to merge 1 commit into
broadinstitute:masterfrom
mianaz:feature/mps-apple-silicon-support
Closed

Add Apple Silicon (MPS) GPU support#437
mianaz wants to merge 1 commit into
broadinstitute:masterfrom
mianaz:feature/mps-apple-silicon-support

Conversation

@mianaz

@mianaz mianaz commented Feb 9, 2026

Copy link
Copy Markdown

Summary

Adds support for running CellBender on Apple Silicon Macs (M1/M2/M3/M4) using PyTorch's Metal Performance Shaders (MPS) backend. This enables GPU-accelerated ambient RNA removal without requiring NVIDIA CUDA hardware.

  • Hybrid architecture: Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback for numerical stability; neural network operations run on MPS for GPU acceleration
  • Posterior-to-CPU routing: After MPS training, posterior computation runs on CPU to ensure noise estimation matches CUDA exactly
  • New CLI flags: --mps, --posterior-device, --posterior-debug

Validation Results

Tested with 150 epochs on a sample dataset across CPU, MPS (Apple M4 Max), and CUDA (NVIDIA RTX 5060):

Metric MPS CUDA Difference
Noise removed 11.0% 11.1% 0.1%
Cells found 11,843 11,627 1.9%
Per-cell count correlation - - 0.994
Per-gene count correlation - - 0.9999
Ambient expression cosine sim - - 0.999
Cell overlap (Jaccard) - - 82.1%

10-epoch three-way comparison: CPU 7.66%, MPS 7.66%, CUDA 7.73%.

Changes by Category

New Files (3)

File Lines Purpose
device_utils.py 232 Device detection, posterior-device resolution, model/param-store moves
pyro_mps_patch.py 315 Monkey-patches for Pyro _Subsample and PyTorch distribution MPS compatibility
mps_diagnostics.py 672 Diagnostic utilities for comparing MPS vs CUDA posterior outputs

Modified Files (16)

MPS Device Support

  • argparser.py: Add --mps, --posterior-device, --posterior-debug CLI flags
  • cli.py: Wire new args through, set use_mps default, handle MPS device selection
  • run.py: Integrate device_utils for device selection, add MPS-specific posterior routing, extract _construct_model_and_loaders() helper
  • model.py: Add use_mps attribute, _make_all_parameters_contiguous() method, MPS-aware device handling in forward/guide
  • estimation.py: Use device_utils for GPU detection instead of CUDA-only checks
  • monitor.py: MPS-aware memory reporting, device-agnostic monitoring
  • train.py: MPS batch-skip for intermittent shape mismatch errors, device-aware training loop

Numerical Stability Fixes

  • NegativeBinomialPoissonConvApprox.py: .contiguous() before lgamma/digamma to prevent NaN gradients on MPS
  • posterior.py: CPU fallback for torch.nonzero on MPS, CPU-only noise log-prob when training used MPS, chi_ambient floor, posterior-device support
  • report.py: .values for pandas Series → numpy conversion (sparse matrix indexing fix), HTML report styling

PyTorch Compatibility

  • checkpoint.py: PyTorch 2.6+ weights_only=False + safe_globals, robust format detection for save/load
  • data/dataprep.py: Replace CUDA-specific tensor creation with device-agnostic calls
  • data/dataset.py: Device-agnostic tensor operations
  • data/io.py: MPS device handling for data loading

Tests

  • tests/test_posterior.py: Add posterior device override tests

Key Technical Details

Why Hybrid (not Full-MPS)?

Full-MPS had scale-dependent failures: works on small datasets, fails on large ones (25,000+ droplets). Root causes:

  1. MPS lgamma/digamma produce NaN on non-contiguous tensors from .expand()
  2. torch.nonzero returns wrong indices on MPS
  3. Gradient approximation errors accumulate at scale

MPS Traps Documented

  1. Non-contiguous tensors from .expand() → always call .contiguous()
  2. lgamma/digamma on expanded tensors → NaN gradients (CRITICAL)
  3. float64 unsupported on MPS → use float32 everywhere
  4. torch.nonzero returns wrong indices → move to CPU first
  5. Broadcasting with unsqueeze → ensure both tensors are contiguous

Test plan

  • All 41 existing tests pass (pytest cellbender/remove_background/tests/ -v)
  • 10-epoch CPU/MPS/CUDA three-way comparison: all within 0.07% noise removal
  • 150-epoch MPS/CUDA comparison: within 0.1% noise removal
  • Per-cell count correlation > 0.99 between MPS and CUDA outputs
  • Checkpoint save/load round-trip validated
  • Test on additional datasets (community testing welcome)
  • Performance benchmarks on various Apple Silicon chips

Acknowledgments

Based on CellBender v0.3.2. Original work by Stephen Fleming et al.

  • Fleming, S.J., Chaffin, M.D., Awatramani, A. et al. Unsupervised removal of systematic background noise from droplet-based single-cell experiments using CellBender. Nature Methods 20, 1323–1335 (2023). https://doi.org/10.1038/s41592-023-01943-7

🤖 Generated with Claude Code

Add support for running CellBender on Apple Silicon Macs using
PyTorch's MPS backend. This enables GPU-accelerated ambient RNA
removal on M1/M2/M3/M4 Macs without requiring NVIDIA CUDA hardware.

Architecture: Hybrid CPU Sampling + MPS Neural Network
- Distribution sampling (Gamma, Beta, Dirichlet) uses CPU fallback via
  PyTorch's verified implementations to avoid MPS numerical issues
- Neural network operations (encoder, decoder) run on MPS for ~1.5-2x
  speedup vs pure CPU
- Posterior computation routes to CPU after MPS training to ensure
  noise estimation matches CUDA exactly

New CLI flags:
- --mps: Enable MPS (Apple Silicon) GPU acceleration
- --posterior-device {auto,model,cpu,cuda,mps}: Control posterior device
- --posterior-debug: Log detailed posterior diagnostics

Validated results (150 epochs, sample dataset):
- MPS noise removal: 11.0% vs CUDA: 11.1% (within 0.1%)
- MPS cells found: 11,843 vs CUDA: 11,627 (82% Jaccard overlap)
- Per-cell count correlation: 0.994
- Per-gene count correlation: 0.9999
- Ambient expression cosine similarity: 0.999

Key fixes:
- Patch Pyro's _Subsample to accept MPS devices (use_cuda=False)
- CPU fallback for Gamma/Beta/Dirichlet sampling (exact gradients)
- .contiguous() before lgamma/digamma on MPS (NaN gradient fix)
- Posterior-to-CPU routing when training on MPS (noise CDF fix)
- torch.nonzero CPU fallback on MPS (wrong indices bug)
- PyTorch 2.6+ checkpoint compatibility (weights_only handling)
- MPS batch skip for intermittent shape mismatch errors
- pandas Series .values for sparse matrix indexing in reports

New files:
- device_utils.py: Device detection, posterior routing, model moves
- pyro_mps_patch.py: Monkey-patches for Pyro/PyTorch MPS compat
- mps_diagnostics.py: Diagnostic utilities for MPS debugging

Based on CellBender v0.3.2 by Stephen Fleming et al.
Original paper: Fleming et al., Nature Methods, 2023
https://doi.org/10.1038/s41592-023-01943-7

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@mianaz mianaz marked this pull request as draft February 9, 2026 23:51
@mianaz

mianaz commented Feb 10, 2026

Copy link
Copy Markdown
Author

Closing to rework: will incorporate cleaner device abstraction pattern from #429 and resubmit with squashed history. The numerical fixes and validation results remain the same.

@mianaz mianaz closed this Feb 10, 2026
@mianaz mianaz deleted the feature/mps-apple-silicon-support branch February 10, 2026 00:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant