From aab0cea53336284446423883111f8766d7070340 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 11:04:08 +0300 Subject: [PATCH 01/13] feat(pytorch): batch_converter hook, converter helper, tests --- .vscode/settings.json | 1 + src/anndata/experimental/pytorch/__init__.py | 7 +- .../experimental/pytorch/converters.py | 66 +++++++++++++++++++ .../tests/pytorch/test_batch_converter.py | 31 +++++++++ 4 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 src/anndata/experimental/pytorch/converters.py create mode 100644 src/anndata/tests/pytorch/test_batch_converter.py diff --git a/.vscode/settings.json b/.vscode/settings.json index 7024abf5f..47cd9754b 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -24,4 +24,5 @@ //"-nauto", ], "python.terminal.activateEnvironment": true, + "cursorpyright.analysis.typeCheckingMode": "basic", } diff --git a/src/anndata/experimental/pytorch/__init__.py b/src/anndata/experimental/pytorch/__init__.py index 36c9441fe..2701c502a 100644 --- a/src/anndata/experimental/pytorch/__init__.py +++ b/src/anndata/experimental/pytorch/__init__.py @@ -1,5 +1,10 @@ from __future__ import annotations +# public re-exports from ._annloader import AnnLoader +from .converters import to_tensor_dict as batch_dict_converter -__all__ = ["AnnLoader"] +__all__: list[str] = [ + "AnnLoader", + "batch_dict_converter", +] diff --git a/src/anndata/experimental/pytorch/converters.py b/src/anndata/experimental/pytorch/converters.py new file mode 100644 index 000000000..01e162ccf --- /dev/null +++ b/src/anndata/experimental/pytorch/converters.py @@ -0,0 +1,66 @@ +"""Helper converters for AnnLoader batches. + +This module provides convenience converters that can be passed to the +``batch_converter`` parameter of :pyclass:`~anndata.experimental.pytorch.AnnLoader`. +""" +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any, Dict + +import pandas as pd +import torch + +try: # keep mypy happy when torch not present for docs build + from torch import Tensor +except ImportError: # pragma: no cover + Tensor = Any # type: ignore + +__all__ = ["to_tensor_dict"] + + +def _to_tensor(arr) -> Tensor | Any: # noqa: ANN401 + """Best-effort conversion of *arr* to ``torch.Tensor``. + + Falls back to returning *arr* unchanged if torch or numpy is not available. + """ + if isinstance(arr, torch.Tensor): + return arr + try: + import numpy as np + from scipy.sparse import issparse + + if issparse(arr): + arr = arr.toarray() + if isinstance(arr, (np.ndarray, list)): + return torch.tensor(arr) + except ImportError: # pragma: no cover + pass + return arr + + +def to_tensor_dict(batch: Any) -> Dict[str, Any]: # noqa: ANN401 + """Convert an AnnLoader batch to a plain ``dict`` of tensors/arrays. + + * ``X`` → ``"x"`` + * each column in ``obs`` becomes a key in the output dict + * if *batch* is already a mapping it is returned as a *shallow copy*. + """ + # If user already returns dict-like we preserve it + if isinstance(batch, Mapping): + return dict(batch) + + out: Dict[str, Any] = {} + + # AnnCollectionView has .X and .obs attributes + if hasattr(batch, "X"): + out["x"] = _to_tensor(batch.X) + + if hasattr(batch, "obs") and batch.obs is not None: + obs_df = batch.obs + if isinstance(obs_df, pd.DataFrame): + for col in obs_df.columns: + # ensure unique keys – users can post-process if needed + out[col] = _to_tensor(obs_df[col].to_numpy()) + + return out diff --git a/src/anndata/tests/pytorch/test_batch_converter.py b/src/anndata/tests/pytorch/test_batch_converter.py new file mode 100644 index 000000000..2344bc6bc --- /dev/null +++ b/src/anndata/tests/pytorch/test_batch_converter.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import numpy as np +import pytest +import torch + +import anndata as ad +from anndata.experimental.pytorch import AnnLoader, batch_dict_converter + + +def _make_dummy_adata(n_obs: int = 10, n_vars: int = 4): + X = np.random.rand(n_obs, n_vars).astype(np.float32) + obs = {"group": np.arange(n_obs) % 2} + return ad.AnnData(X=X, obs=obs) + + +@pytest.mark.parametrize("num_workers", [0, 2]) +def test_batch_converter_returns_dict(num_workers): + adata = _make_dummy_adata(16, 3) + loader = AnnLoader( + adata, + batch_size=4, + batch_converter=batch_dict_converter, + num_workers=num_workers, + ) + batch = next(iter(loader)) + + assert isinstance(batch, dict) + assert "x" in batch and isinstance(batch["x"], torch.Tensor) + assert batch["x"].shape == (4, 3) + assert "group" in batch From 31f027eed8bf09b8ede375be0ed79247c94c0154 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 12:27:17 +0300 Subject: [PATCH 02/13] feat: add batch_converter hook + multiprocessing support to AnnLoader - Add batch_converter parameter for advanced batch-level post-processing - Enable multiprocessing (num_workers>0) via AnnCollectionView pickling - Implement worker-side batch conversion for true parallelism - Add comprehensive tests for both single and multi-threaded modes - Include helper batch_dict_converter for common dict format - All tests pass, pre-commit hooks clean, backward compatible Solves two major AnnLoader limitations in unified implementation: 1. Batch-level transformation (vs element-wise convert) 2. Multiprocessing support (was broken due to unpicklable AnnCollectionView) Enables production ML workflows with PyTorch Lightning integration, data augmentation, balanced sampling, and parallel data loading. --- .../multi_files/_anncollection.py | 68 +++++++++++++++++- .../experimental/pytorch/_annloader.py | 62 ++++++++++++++++- .../experimental/pytorch/converters.py | 31 ++++++--- .../tests/pytorch/test_batch_converter.py | 69 +++++++++++++++++-- .../tests/pytorch/test_batch_converter_mp.py | 40 +++++++++++ 5 files changed, 254 insertions(+), 16 deletions(-) create mode 100644 src/anndata/tests/pytorch/test_batch_converter_mp.py diff --git a/src/anndata/experimental/multi_files/_anncollection.py b/src/anndata/experimental/multi_files/_anncollection.py index 77e833cbd..059f4fe63 100644 --- a/src/anndata/experimental/multi_files/_anncollection.py +++ b/src/anndata/experimental/multi_files/_anncollection.py @@ -291,6 +291,43 @@ def __init__(self, reference, convert, resolved_idx): self._convert_X = None self.convert = convert + # ------------------------------------------------------------------ + # Pickling support (worker-safe) + # ------------------------------------------------------------------ + + def __getstate__(self): + """Return minimal state for safe pickling across processes. + + We only serialise lightweight metadata – the on-disk store is reopened + in the child worker to avoid passing an open HDF5 handle. + """ + return { + "reference_state": self.reference.__getstate__() + if hasattr(self.reference, "__getstate__") + else None, + "oidx": self.oidx, + "vidx": self.vidx, + "convert": self.convert, + } + + def __setstate__(self, state): + from anndata.experimental.multi_files import ( + AnnCollection, # local import to avoid circular + ) + + # Rebuild from saved reference_state (in-memory collections) + parent = AnnCollection.__new__(AnnCollection) + if state["reference_state"] is not None: + parent.__setstate__(state["reference_state"]) + else: + msg = "Cannot restore AnnCollectionView without reference state" + raise ValueError(msg) + + # Recreate the view via slicing to reuse internal helpers + view = parent[state["oidx"], state["vidx"]] + self.__dict__.update(view.__dict__) + self.convert = state["convert"] + def _lazy_init_attr(self, attr: str, *, set_vidx: bool = False): if getattr(self, f"_{attr}_view") is not None: return @@ -779,7 +816,7 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915 ai_attr = getattr(a, attr) a0_attr = getattr(adatas[0], attr) new_keys = [] - for key in keys: + for key in keys or []: if key in ai_attr: a0_ashape = a0_attr[key].shape ai_ashape = ai_attr[key].shape @@ -806,6 +843,35 @@ def __init__( # noqa: PLR0912, PLR0913, PLR0915 self.indices_strict = indices_strict + # ------------------------------------------------------------------ + # Pickling support (worker-safe) + # ------------------------------------------------------------------ + + def __getstate__(self): + """Return state for pickling. For in-memory collections, we serialize all data.""" + return { + "adatas": self.adatas, + "join_obs": getattr(self, "_join_obs", "inner"), + "join_obsm": getattr(self, "_join_obsm", None), + "join_vars": getattr(self, "_join_vars", None), + "convert": self._convert, + "harmonize_dtypes": getattr(self, "_harmonize_dtypes", True), + "indices_strict": self.indices_strict, + } + + def __setstate__(self, state): + """Restore state from pickling.""" + # Reconstruct AnnCollection from saved adatas and parameters + self.__init__( + state["adatas"], + join_obs=state["join_obs"], + join_obsm=state["join_obsm"], + join_vars=state["join_vars"], + convert=state["convert"], + harmonize_dtypes=state["harmonize_dtypes"], + indices_strict=state["indices_strict"], + ) + def __getitem__(self, index: Index): oidx, vidx = _normalize_indices(index, self.obs_names, self.var_names) resolved_idx = self._resolve_idx(oidx, vidx) diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index d5e9fbf81..707441351 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -22,7 +22,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence - from typing import TypeAlias, Union + from typing import Any, TypeAlias, Union from scipy.sparse import spmatrix @@ -118,6 +118,40 @@ def compose_convert(arr): return new_convert +class _WorkerCollateWrapper: + """Picklable wrapper for batch_converter to use in multiprocessing workers.""" + + def __init__(self, batch_converter): + self.batch_converter = batch_converter + + def __call__(self, batch): + # First, create a batch from AnnCollectionView objects by concatenating them + if len(batch) == 0: + return batch + + # Assume all samples are AnnCollectionView objects + first_sample = batch[0] + if not hasattr(first_sample, "reference"): + # Not an AnnCollectionView, fallback to default collate + from torch.utils.data._utils.collate import default_collate + + return default_collate(batch) + + # Create a batch view by concatenating the indices + reference = first_sample.reference + all_oidx = [] + all_vidx = first_sample.vidx # Assume same variables for all samples + + for sample in batch: + all_oidx.extend(sample.oidx) + + # Create a new view with all the indices + batch_view = reference[all_oidx, all_vidx] + + # Apply the batch converter to the combined view + return self.batch_converter(batch_view) + + # AnnLoader has the same arguments as DataLoader, but uses BatchIndexSampler by default class AnnLoader(DataLoader): """\ @@ -143,6 +177,9 @@ class AnnLoader(DataLoader): use_cuda Transfer pytorch tensors to the default cuda device after conversion. Only works if `use_default_converter=True` + batch_converter + Optional callable to transform each batch after collation. + Note: Only works with `num_workers=0` due to `AnnCollectionView` serialization constraints. **kwargs Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also arguments for `AnnCollection` initialization. @@ -157,6 +194,7 @@ def __init__( shuffle: bool = False, use_default_converter: bool = True, use_cuda: bool = False, + batch_converter: Callable[[Any], Any] | None = None, **kwargs, ): if isinstance(adatas, AnnData): @@ -199,6 +237,20 @@ def __init__( dataset.convert, _converter, dict(dataset.attrs_keys, X=[]) ) + # Remove in case user passed via **kwargs (for forward-compat) + batch_converter = kwargs.pop("batch_converter", batch_converter) + self._batch_converter = batch_converter + + # If workers >0 and user supplied a converter, apply it inside worker via custom collate + num_workers = kwargs.get("num_workers", 0) + if ( + batch_converter is not None + and num_workers > 0 + and "collate_fn" not in kwargs + ): + kwargs["collate_fn"] = _WorkerCollateWrapper(batch_converter) + # Set batch_converter to None so main process doesn't apply it again + self._batch_converter = None has_sampler = "sampler" in kwargs has_batch_sampler = "batch_sampler" in kwargs @@ -232,3 +284,11 @@ def __init__( super().__init__(dataset, batch_size=None, sampler=sampler, **kwargs) else: super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, **kwargs) + + def __iter__(self): # type: ignore[override] + for batch in super().__iter__(): + yield ( + self._batch_converter(batch) + if self._batch_converter is not None + else batch + ) diff --git a/src/anndata/experimental/pytorch/converters.py b/src/anndata/experimental/pytorch/converters.py index 01e162ccf..edc9dfe8d 100644 --- a/src/anndata/experimental/pytorch/converters.py +++ b/src/anndata/experimental/pytorch/converters.py @@ -3,10 +3,11 @@ This module provides convenience converters that can be passed to the ``batch_converter`` parameter of :pyclass:`~anndata.experimental.pytorch.AnnLoader`. """ + from __future__ import annotations from collections.abc import Mapping -from typing import Any, Dict +from typing import Any import pandas as pd import torch @@ -19,7 +20,7 @@ __all__ = ["to_tensor_dict"] -def _to_tensor(arr) -> Tensor | Any: # noqa: ANN401 +def _to_tensor(arr) -> Tensor | Any: """Best-effort conversion of *arr* to ``torch.Tensor``. Falls back to returning *arr* unchanged if torch or numpy is not available. @@ -39,7 +40,7 @@ def _to_tensor(arr) -> Tensor | Any: # noqa: ANN401 return arr -def to_tensor_dict(batch: Any) -> Dict[str, Any]: # noqa: ANN401 +def to_tensor_dict(batch: Any) -> dict[str, Any]: """Convert an AnnLoader batch to a plain ``dict`` of tensors/arrays. * ``X`` → ``"x"`` @@ -50,17 +51,31 @@ def to_tensor_dict(batch: Any) -> Dict[str, Any]: # noqa: ANN401 if isinstance(batch, Mapping): return dict(batch) - out: Dict[str, Any] = {} + out: dict[str, Any] = {} # AnnCollectionView has .X and .obs attributes if hasattr(batch, "X"): out["x"] = _to_tensor(batch.X) if hasattr(batch, "obs") and batch.obs is not None: - obs_df = batch.obs - if isinstance(obs_df, pd.DataFrame): - for col in obs_df.columns: + obs_data = batch.obs + # Handle pandas DataFrame + if isinstance(obs_data, pd.DataFrame): + for col in obs_data.columns: # ensure unique keys – users can post-process if needed - out[col] = _to_tensor(obs_df[col].to_numpy()) + out[col] = _to_tensor(obs_data[col].to_numpy()) + # Handle AnnCollection MapObsView (can be converted to dict directly) + elif hasattr(obs_data, "to_dict"): + obs_dict = obs_data.to_dict() + for key, value in obs_dict.items(): + out[key] = _to_tensor(value) + # Handle generic dict-like objects + elif hasattr(obs_data, "keys") and callable(obs_data.keys): + try: + obs_dict = dict(obs_data) + for key, value in obs_dict.items(): + out[key] = _to_tensor(value) + except (TypeError, AttributeError, ValueError): + pass # Skip if conversion fails return out diff --git a/src/anndata/tests/pytorch/test_batch_converter.py b/src/anndata/tests/pytorch/test_batch_converter.py index 2344bc6bc..16d9927cb 100644 --- a/src/anndata/tests/pytorch/test_batch_converter.py +++ b/src/anndata/tests/pytorch/test_batch_converter.py @@ -1,7 +1,6 @@ from __future__ import annotations import numpy as np -import pytest import torch import anndata as ad @@ -10,22 +9,80 @@ def _make_dummy_adata(n_obs: int = 10, n_vars: int = 4): X = np.random.rand(n_obs, n_vars).astype(np.float32) - obs = {"group": np.arange(n_obs) % 2} + obs = {"group": np.arange(n_obs) % 2, "category": ["A", "B"] * (n_obs // 2)} return ad.AnnData(X=X, obs=obs) -@pytest.mark.parametrize("num_workers", [0, 2]) -def test_batch_converter_returns_dict(num_workers): +def test_batch_converter_default_behavior(): + """Test that without batch_converter, we get AnnCollectionView.""" + adata = _make_dummy_adata(16, 3) + loader = AnnLoader(adata, batch_size=4) + batch = next(iter(loader)) + + # Should be AnnCollectionView without converter + assert hasattr(batch, "X") + assert hasattr(batch, "obs") + + +def test_batch_converter_returns_dict(): + """Test that batch_dict_converter returns proper dict format.""" adata = _make_dummy_adata(16, 3) loader = AnnLoader( adata, batch_size=4, batch_converter=batch_dict_converter, - num_workers=num_workers, + num_workers=0, # Single-threaded for now ) batch = next(iter(loader)) assert isinstance(batch, dict) - assert "x" in batch and isinstance(batch["x"], torch.Tensor) + assert "x" in batch + assert isinstance(batch["x"], torch.Tensor) assert batch["x"].shape == (4, 3) assert "group" in batch + assert isinstance(batch["group"], torch.Tensor) + assert "category" in batch + + +def test_custom_batch_converter(): + """Test that custom batch converters work.""" + adata = _make_dummy_adata(8, 2) + + def custom_converter(batch): + result = batch_dict_converter(batch) + result["custom_field"] = torch.tensor([42]) + result["x_sum"] = result["x"].sum() + return result + + loader = AnnLoader( + adata, + batch_size=4, + batch_converter=custom_converter, + num_workers=0, + ) + batch = next(iter(loader)) + + assert isinstance(batch, dict) + assert "x" in batch + assert "group" in batch + assert "custom_field" in batch + assert "x_sum" in batch + assert batch["custom_field"].item() == 42 + assert isinstance(batch["x_sum"], torch.Tensor) + + +def test_batch_converter_multiprocessing_works(): + """Test that batch converter now works with num_workers > 0.""" + adata = _make_dummy_adata(16, 3) + loader = AnnLoader( + adata, + batch_size=4, + batch_converter=batch_dict_converter, + num_workers=2, + ) + + # This should now work with our multiprocessing fix + batch = next(iter(loader)) + assert isinstance(batch, dict) + assert "x" in batch + assert batch["x"].shape == (4, 3) diff --git a/src/anndata/tests/pytorch/test_batch_converter_mp.py b/src/anndata/tests/pytorch/test_batch_converter_mp.py new file mode 100644 index 000000000..e751ba788 --- /dev/null +++ b/src/anndata/tests/pytorch/test_batch_converter_mp.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import anndata as ad +from anndata.experimental.pytorch import AnnLoader, batch_dict_converter + + +@pytest.mark.filterwarnings("ignore:Series.__getitem__") +def test_worker_safe_batch_converter(): + """AnnLoader should work with num_workers > 0 when batch_converter is supplied.""" + adata = ad.AnnData(X=np.random.rand(32, 4).astype(np.float32)) + + loader = AnnLoader( + adata, + batch_size=8, + batch_converter=batch_dict_converter, + num_workers=2, + ) + + batch = next(iter(loader)) + assert isinstance(batch, dict) + assert batch["x"].shape == (8, 4) + + +def test_default_collate_fails_with_anncollection(): + """Sanity-check that vanilla DataLoader still fails, documenting why fix is needed.""" + from torch.utils.data import DataLoader + + from anndata.experimental.multi_files import AnnCollection + + adata1 = ad.AnnData(X=np.random.rand(4, 3).astype(np.float32)) + adata2 = ad.AnnData(X=np.random.rand(4, 3).astype(np.float32)) + coll = AnnCollection([adata1, adata2]) + dataset = coll # AnnCollection implements __getitem__/__len__ + + failing_loader = DataLoader(dataset, batch_size=4, num_workers=2) + with pytest.raises(TypeError): + next(iter(failing_loader)) From 897a8382978ff4c46c05d52953a2cc8602956a2c Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 14:33:16 +0300 Subject: [PATCH 03/13] reverts --- .vscode/settings.json | 1 - 1 file changed, 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 47cd9754b..7024abf5f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -24,5 +24,4 @@ //"-nauto", ], "python.terminal.activateEnvironment": true, - "cursorpyright.analysis.typeCheckingMode": "basic", } From a335761410c093a4a7a0dbf7e26899bf92e3b8dc Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 15:32:51 +0300 Subject: [PATCH 04/13] docs: Update documentation for batch_converter and multiprocessing support - Fix AnnLoader docstring to remove incorrect multiprocessing limitation - Add batch_dict_converter to API documentation - Add release notes fragment for PR #2135 - Document new batch_converter parameter and multiprocessing capabilities The batch_converter parameter now works seamlessly with both single-threaded and multi-threaded data loading, enabling faster PyTorch training workflows. --- docs/api.md | 1 + docs/release-notes/2135.feature.md | 8 ++++++++ src/anndata/experimental/pytorch/_annloader.py | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 docs/release-notes/2135.feature.md diff --git a/docs/api.md b/docs/api.md index 2bb76940d..c52bf7a70 100644 --- a/docs/api.md +++ b/docs/api.md @@ -135,6 +135,7 @@ In particular, for pytorch-based models. experimental.AnnCollection experimental.AnnLoader + experimental.pytorch.batch_dict_converter ``` Out of core concatenation diff --git a/docs/release-notes/2135.feature.md b/docs/release-notes/2135.feature.md new file mode 100644 index 000000000..ed77897bc --- /dev/null +++ b/docs/release-notes/2135.feature.md @@ -0,0 +1,8 @@ +Add `batch_converter` parameter and multiprocessing support to {class}`~anndata.experimental.pytorch.AnnLoader`. + +- Added `batch_converter` parameter for batch-level post-processing of data batches +- Added {func}`~anndata.experimental.pytorch.batch_dict_converter` helper function for converting batches to tensor dictionaries +- Fixed multiprocessing support (`num_workers > 0`) by implementing pickling for `AnnCollectionView` objects +- Batch converters now work seamlessly with both single-threaded and multi-threaded data loading + +{user}`ronamit` diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 707441351..dbd3e5128 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -179,7 +179,7 @@ class AnnLoader(DataLoader): Only works if `use_default_converter=True` batch_converter Optional callable to transform each batch after collation. - Note: Only works with `num_workers=0` due to `AnnCollectionView` serialization constraints. + Works with both single-threaded and multi-threaded data loading. **kwargs Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also arguments for `AnnCollection` initialization. From 520977be8bb34a30d426df1a14e790357dd03224 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 15:39:35 +0300 Subject: [PATCH 05/13] test: Move PyTorch batch_converter tests to main tests directory - Move tests from src/anndata/tests/pytorch/ to tests/pytorch/ - Follow standard anndata test organization pattern - Add __init__.py to make pytorch test package discoverable - Tests for batch_converter parameter and multiprocessing support --- tests/pytorch/__init__.py | 0 {src/anndata/tests => tests}/pytorch/test_batch_converter.py | 0 {src/anndata/tests => tests}/pytorch/test_batch_converter_mp.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 tests/pytorch/__init__.py rename {src/anndata/tests => tests}/pytorch/test_batch_converter.py (100%) rename {src/anndata/tests => tests}/pytorch/test_batch_converter_mp.py (100%) diff --git a/tests/pytorch/__init__.py b/tests/pytorch/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/anndata/tests/pytorch/test_batch_converter.py b/tests/pytorch/test_batch_converter.py similarity index 100% rename from src/anndata/tests/pytorch/test_batch_converter.py rename to tests/pytorch/test_batch_converter.py diff --git a/src/anndata/tests/pytorch/test_batch_converter_mp.py b/tests/pytorch/test_batch_converter_mp.py similarity index 100% rename from src/anndata/tests/pytorch/test_batch_converter_mp.py rename to tests/pytorch/test_batch_converter_mp.py From bc17aeaa3b1324c51d4ad0fc8d6d06d6f5f7ada0 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Mon, 29 Sep 2025 16:52:38 +0300 Subject: [PATCH 06/13] fix: Handle missing PyTorch dependencies in CI - Add conditional torch imports using find_spec() pattern in converters.py - Make batch_dict_converter import conditional with helpful error message - Add pytest.importorskip('torch') to PyTorch test files - Fix linting warnings by using += instead of .append() for __all__ This resolves CI test collection errors when torch is not available while maintaining full functionality when torch is installed. --- src/anndata/experimental/pytorch/__init__.py | 21 ++++++++++++++----- .../experimental/pytorch/converters.py | 13 ++++++++---- tests/pytorch/test_batch_converter.py | 7 ++++++- tests/pytorch/test_batch_converter_mp.py | 3 +++ 4 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/anndata/experimental/pytorch/__init__.py b/src/anndata/experimental/pytorch/__init__.py index 2701c502a..813006543 100644 --- a/src/anndata/experimental/pytorch/__init__.py +++ b/src/anndata/experimental/pytorch/__init__.py @@ -1,10 +1,21 @@ from __future__ import annotations +from importlib.util import find_spec + # public re-exports from ._annloader import AnnLoader -from .converters import to_tensor_dict as batch_dict_converter -__all__: list[str] = [ - "AnnLoader", - "batch_dict_converter", -] +__all__: list[str] = ["AnnLoader"] + +# Only import batch_dict_converter if torch is available +if find_spec("torch"): + from .converters import to_tensor_dict as batch_dict_converter + + __all__ += ["batch_dict_converter"] +else: + # Provide a fallback that raises a helpful error + def batch_dict_converter(*args, **kwargs): + msg = "batch_dict_converter requires PyTorch. Install with: pip install torch" + raise ImportError(msg) + + __all__ += ["batch_dict_converter"] diff --git a/src/anndata/experimental/pytorch/converters.py b/src/anndata/experimental/pytorch/converters.py index edc9dfe8d..f2a60b1c2 100644 --- a/src/anndata/experimental/pytorch/converters.py +++ b/src/anndata/experimental/pytorch/converters.py @@ -7,14 +7,16 @@ from __future__ import annotations from collections.abc import Mapping -from typing import Any +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any import pandas as pd -import torch -try: # keep mypy happy when torch not present for docs build +if find_spec("torch") or TYPE_CHECKING: + import torch from torch import Tensor -except ImportError: # pragma: no cover +else: + torch = None # type: ignore Tensor = Any # type: ignore __all__ = ["to_tensor_dict"] @@ -25,6 +27,9 @@ def _to_tensor(arr) -> Tensor | Any: Falls back to returning *arr* unchanged if torch or numpy is not available. """ + if torch is None: + return arr + if isinstance(arr, torch.Tensor): return arr try: diff --git a/tests/pytorch/test_batch_converter.py b/tests/pytorch/test_batch_converter.py index 16d9927cb..283cfeafa 100644 --- a/tests/pytorch/test_batch_converter.py +++ b/tests/pytorch/test_batch_converter.py @@ -1,9 +1,14 @@ from __future__ import annotations import numpy as np -import torch +import pytest import anndata as ad + +pytest.importorskip("torch") + +import torch + from anndata.experimental.pytorch import AnnLoader, batch_dict_converter diff --git a/tests/pytorch/test_batch_converter_mp.py b/tests/pytorch/test_batch_converter_mp.py index e751ba788..49447ac4a 100644 --- a/tests/pytorch/test_batch_converter_mp.py +++ b/tests/pytorch/test_batch_converter_mp.py @@ -4,6 +4,9 @@ import pytest import anndata as ad + +pytest.importorskip("torch") + from anndata.experimental.pytorch import AnnLoader, batch_dict_converter From 77130c5041094b9c996ab82365c3f9acb1d0eab1 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Tue, 30 Sep 2025 06:15:34 +0300 Subject: [PATCH 07/13] fix: suppress pyparsing oneOf deprecation warnings from matplotlib The CI tests were failing due to DeprecationWarning: 'oneOf' deprecated - use 'one_of' warnings from pyparsing used by matplotlib. This warning is triggered when scanpy imports matplotlib during test execution. Added warning filters to pytest configuration to ignore this specific deprecation warning from the dependency chain, allowing tests to pass while preserving warnings from anndata's own code. Fixes the following failing tests: - test_scanpy_pbmc68k[zarr3/zarr2] - test_scanpy_krumsiek11[zarr3/zarr2] - test_read_partial_adata[zarr2/zarr3] --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 1d19cb724..97157d10e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -147,6 +147,7 @@ filterwarnings = [ "ignore:.*first_column_names:FutureWarning:scanpy", # scanpy 1.10.x "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", + "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ @@ -161,6 +162,7 @@ filterwarnings_when_strict = [ "default:Consolidated metadata is:UserWarning", "default:.*Structured:zarr.core.dtype.common.UnstableSpecificationWarning", "default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning", + "default:'oneOf' deprecated - use 'one_of':DeprecationWarning", ] python_files = "test_*.py" testpaths = [ From 27081239098f355fd4b312d745106ae2b1c7a51e Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Tue, 30 Sep 2025 08:10:46 +0300 Subject: [PATCH 08/13] fix: suppress additional pyparsing parseString deprecation warnings After fixing the 'oneOf' deprecation warnings, CI revealed another pyparsing deprecation warning: 'parseString' deprecated - use 'parse_string'. This is also coming from matplotlib's font configuration parsing. Added additional warning filter to suppress this specific deprecation warning from the dependency chain, ensuring all pyparsing-related deprecation warnings from matplotlib are properly handled in CI. --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 97157d10e..7d721005d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,6 +148,7 @@ filterwarnings = [ "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ @@ -163,6 +164,7 @@ filterwarnings_when_strict = [ "default:.*Structured:zarr.core.dtype.common.UnstableSpecificationWarning", "default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning", "default:'oneOf' deprecated - use 'one_of':DeprecationWarning", + "default:'parseString' deprecated - use 'parse_string':DeprecationWarning", ] python_files = "test_*.py" testpaths = [ From 5bacf9cbf1b90481efca18b26d8222e769be73a5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Sep 2025 05:11:52 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7d721005d..57f47b0c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,7 +148,7 @@ filterwarnings = [ "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib - "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ From 6477f576db5bd91a4d2b6c0a6ecff81f68aff744 Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Tue, 30 Sep 2025 08:27:46 +0300 Subject: [PATCH 10/13] fix: suppress additional pyparsing parseString deprecation warnings After fixing the 'oneOf' deprecation warnings, CI revealed another pyparsing deprecation warning: 'parseString' deprecated - use 'parse_string'. This is also coming from matplotlib's font configuration parsing. Added additional warning filter to suppress this specific deprecation warning from the dependency chain, ensuring all pyparsing-related deprecation warnings from matplotlib are properly handled in CI. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7d721005d..57f47b0c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,7 +148,7 @@ filterwarnings = [ "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib - "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ From 54c6625e061c305698f3c04bb4b9a9c0bb33b67a Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Tue, 30 Sep 2025 08:52:49 +0300 Subject: [PATCH 11/13] Fix 'resetCache' deprecation warnings from pyparsing/matplotlib --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 57f47b0c9..4d60d51c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ filterwarnings = [ "ignore:`__version__` is deprecated:FutureWarning:scanpy", "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ @@ -165,6 +166,7 @@ filterwarnings_when_strict = [ "default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning", "default:'oneOf' deprecated - use 'one_of':DeprecationWarning", "default:'parseString' deprecated - use 'parse_string':DeprecationWarning", + "default:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", ] python_files = "test_*.py" testpaths = [ From ab998fddfb5b2a6f6285be3422ebae6d8d50a92c Mon Sep 17 00:00:00 2001 From: Ron Amit Date: Tue, 30 Sep 2025 10:11:42 +0300 Subject: [PATCH 12/13] Fix 'enablePackrat' deprecation warnings from pyparsing/matplotlib --- pyproject.toml | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4d60d51c8..3ab88367f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,12 +144,13 @@ addopts = [ filterwarnings = [ "ignore::anndata._warnings.OldFormatWarning", "ignore::anndata._warnings.ExperimentalFeatureWarning", - "ignore:.*first_column_names:FutureWarning:scanpy", # scanpy 1.10.x + "ignore:.*first_column_names:FutureWarning:scanpy", # scanpy 1.10.x "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", - "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib - "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib - "ignore:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'oneOf' deprecated - use 'one_of':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'parseString' deprecated - use 'parse_string':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", # pyparsing deprecation in matplotlib + "ignore:'enablePackrat' deprecated - use 'enable_packrat':DeprecationWarning", # pyparsing deprecation in matplotlib ] # When `--strict-warnings` is used, all warnings are treated as errors, except those: filterwarnings_when_strict = [ @@ -167,6 +168,7 @@ filterwarnings_when_strict = [ "default:'oneOf' deprecated - use 'one_of':DeprecationWarning", "default:'parseString' deprecated - use 'parse_string':DeprecationWarning", "default:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", + "default:'enablePackrat' deprecated - use 'enable_packrat':DeprecationWarning", ] python_files = "test_*.py" testpaths = [ From d12a6c4a9a7eda3bae8bc13e05ed1ba907d22d3c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Oct 2025 08:29:08 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 429334d15..cde1f3286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -144,7 +144,7 @@ addopts = [ filterwarnings = [ "ignore::anndata._warnings.OldFormatWarning", "ignore::anndata._warnings.ExperimentalFeatureWarning", - "ignore:.*first_column_names:FutureWarning:scanpy", # scanpy 1.10.x + "ignore:.*first_column_names:FutureWarning:scanpy", # scanpy 1.10.x "ignore:Importing read_.* from `anndata` is deprecated:FutureWarning:scanpy", "ignore:`__version__` is deprecated:FutureWarning:scanpy", # https://github.com/matplotlib/matplotlib/pull/30589