diff --git a/docs/api.md b/docs/api.md index 2bb76940d..c52bf7a70 100644 --- a/docs/api.md +++ b/docs/api.md @@ -135,6 +135,7 @@ In particular, for pytorch-based models. experimental.AnnCollection experimental.AnnLoader + experimental.pytorch.batch_dict_converter ``` Out of core concatenation diff --git a/docs/release-notes/2135.feature.md b/docs/release-notes/2135.feature.md new file mode 100644 index 000000000..ed77897bc --- /dev/null +++ b/docs/release-notes/2135.feature.md @@ -0,0 +1,8 @@ +Add `batch_converter` parameter and multiprocessing support to {class}`~anndata.experimental.pytorch.AnnLoader`. + +- Added `batch_converter` parameter for batch-level post-processing of data batches +- Added {func}`~anndata.experimental.pytorch.batch_dict_converter` helper function for converting batches to tensor dictionaries +- Fixed multiprocessing support (`num_workers > 0`) by implementing pickling for `AnnCollectionView` objects +- Batch converters now work seamlessly with both single-threaded and multi-threaded data loading + +{user}`ronamit` diff --git a/pyproject.toml b/pyproject.toml index 9b05b50eb..cde1f3286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,10 @@ filterwarnings_when_strict = [ "default:Consolidated metadata is:UserWarning", "default:.*Structured:zarr.core.dtype.common.UnstableSpecificationWarning", "default:.*FixedLengthUTF32:zarr.core.dtype.common.UnstableSpecificationWarning", + "default:'oneOf' deprecated - use 'one_of':DeprecationWarning", + "default:'parseString' deprecated - use 'parse_string':DeprecationWarning", + "default:'resetCache' deprecated - use 'reset_cache':DeprecationWarning", + "default:'enablePackrat' deprecated - use 'enable_packrat':DeprecationWarning", ] python_files = "test_*.py" testpaths = [ diff --git a/src/anndata/experimental/multi_files/_anncollection.py b/src/anndata/experimental/multi_files/_anncollection.py index 77e833cbd..059f4fe63 100644 --- a/src/anndata/experimental/multi_files/_anncollection.py 
def __getstate__(self):
    """Build the picklable state of this view.

    The state is a plain ``dict`` with four entries:

    * ``"reference_state"`` – the payload returned by the parent's own
      ``__getstate__()``, or ``None`` when the parent has no such method
    * ``"oidx"`` / ``"vidx"`` – this view's observation and variable indices
    * ``"convert"`` – this view's ``convert`` attribute

    NOTE(review): the size of the payload is whatever the parent's
    ``__getstate__`` returns – confirm it is acceptable to ship it per view.
    """
    parent = self.reference
    if hasattr(parent, "__getstate__"):
        parent_state = parent.__getstate__()
    else:
        parent_state = None

    state = {"reference_state": parent_state}
    state["oidx"] = self.oidx
    state["vidx"] = self.vidx
    state["convert"] = self.convert
    return state
def __getstate__(self):
    """Collect the constructor arguments needed to rebuild this collection.

    The returned mapping contains ``self.adatas`` itself, so the wrapped
    objects are pickled together with the collection.

    NOTE(review): the join settings and ``harmonize_dtypes`` are read from
    private attributes and replaced by fixed defaults when an attribute is
    absent – confirm ``__init__`` stores them, otherwise the original
    settings are not round-tripped.
    """
    state = {"adatas": self.adatas}

    # Private attribute suffix -> value used when the attribute is missing.
    join_fallbacks = (
        ("join_obs", "inner"),
        ("join_obsm", None),
        ("join_vars", None),
    )
    for key, fallback in join_fallbacks:
        state[key] = getattr(self, f"_{key}", fallback)

    state["convert"] = self._convert
    state["harmonize_dtypes"] = getattr(self, "_harmonize_dtypes", True)
    state["indices_strict"] = self.indices_strict
    return state
class _WorkerCollateWrapper:
    """Picklable ``collate_fn`` that applies ``batch_converter`` inside workers.

    Being a module-level class (rather than a closure or lambda) keeps the
    wrapper picklable, provided the wrapped ``batch_converter`` is picklable
    too.

    Parameters
    ----------
    batch_converter
        Callable invoked with the assembled batch view; its return value is
        what the wrapper returns.
    """

    def __init__(self, batch_converter):
        self.batch_converter = batch_converter

    def __call__(self, batch):
        """Turn what the DataLoader fetched into a converted batch.

        Three input shapes are handled:

        * an empty container – returned untouched;
        * a single object exposing ``reference`` (an already-batched view) –
          converted directly;
        * a sequence of samples – merged into one view of the shared
          reference and then converted. Sequences whose first element has no
          ``reference`` attribute are delegated to torch's ``default_collate``.
        """
        # Nothing to convert; hand the empty container back untouched.
        if len(batch) == 0:
            return batch

        # With automatic batching disabled (``batch_size=None``) the DataLoader
        # calls ``collate_fn`` with one dataset item, not a list of samples.
        # When that item is already a view, re-slicing it row by row only to
        # rebuild the same selection is wasted work, so convert it as is.
        if hasattr(batch, "reference"):
            return self.batch_converter(batch)

        first_sample = batch[0]
        if not hasattr(first_sample, "reference"):
            # Not a sequence of views: fall back to torch's default collation.
            from torch.utils.data._utils.collate import default_collate

            return default_collate(batch)

        # Merge the per-sample observation indices; the variable index is
        # taken from the first sample (all samples are assumed to share it).
        merged_oidx = [idx for sample in batch for idx in sample.oidx]
        batch_view = first_sample.reference[merged_oidx, first_sample.vidx]

        return self.batch_converter(batch_view)
def __iter__(self):  # type: ignore[override]
    """Yield batches from the parent iterator, post-processing each one.

    Every batch produced by the parent class's ``__iter__`` is passed through
    ``self._batch_converter`` when that attribute is not ``None``; otherwise
    the batch is yielded unchanged. The attribute is looked up for each batch.
    """
    for raw_batch in super().__iter__():
        converter = self._batch_converter
        if converter is None:
            yield raw_batch
        else:
            yield converter(raw_batch)
def _to_tensor(arr) -> Tensor | Any:
    """Best-effort conversion of *arr* to ``torch.Tensor``.

    Sparse matrices are densified first; NumPy arrays and lists are then
    passed to ``torch.tensor``. *arr* is returned as-is (densified if it was
    sparse) when:

    * torch, numpy or scipy cannot be imported,
    * *arr* is already a tensor,
    * *arr* is neither an array nor a list, or
    * torch cannot represent the data (e.g. string / object dtype columns).
    """
    if torch is None or isinstance(arr, torch.Tensor):
        return arr

    try:
        import numpy as np
        from scipy.sparse import issparse
    except ImportError:  # pragma: no cover
        return arr

    if issparse(arr):
        arr = arr.toarray()
    if not isinstance(arr, (np.ndarray, list)):
        return arr

    try:
        return torch.tensor(arr)
    except (TypeError, ValueError, RuntimeError):
        # torch has no dtype for strings/objects or ragged lists; keeping the
        # original values honours the "best-effort" contract instead of
        # aborting the whole batch.
        return arr


def to_tensor_dict(batch: Any) -> dict[str, Any]:
    """Convert an AnnLoader batch to a plain ``dict`` of tensors/arrays.

    * ``X`` → ``"x"``
    * each column in ``obs`` becomes a key in the output dict
    * if *batch* is already a mapping it is returned as a *shallow copy*.

    Values that cannot be turned into tensors (for example string columns)
    are kept in their original array form.

    Note that keys are not de-duplicated: an ``obs`` column named ``"x"``
    overwrites the entry created from ``X``.
    """
    # If user already returns dict-like we preserve it
    if isinstance(batch, Mapping):
        return dict(batch)

    out: dict[str, Any] = {}

    if hasattr(batch, "X"):
        out["x"] = _to_tensor(batch.X)

    # Read ``obs`` once: on view objects the attribute may be rebuilt on
    # every access.
    obs_data = getattr(batch, "obs", None)
    if obs_data is None:
        return out

    if isinstance(obs_data, pd.DataFrame):
        columns = {col: obs_data[col].to_numpy() for col in obs_data.columns}
    elif hasattr(obs_data, "to_dict"):
        # Objects that can export themselves as a dict (e.g. obs views)
        columns = obs_data.to_dict()
    elif hasattr(obs_data, "keys") and callable(obs_data.keys):
        # Generic dict-like object; skipped (best effort) if it cannot be
        # materialised as a dict.
        try:
            columns = dict(obs_data)
        except (TypeError, AttributeError, ValueError):
            columns = {}
    else:
        columns = {}

    for key, value in columns.items():
        out[key] = _to_tensor(value)

    return out
def test_batch_converter_returns_dict():
    """``batch_dict_converter`` must turn a batch into a dict of the expected shape."""
    loader = AnnLoader(
        _make_dummy_adata(16, 3),
        batch_size=4,
        batch_converter=batch_dict_converter,
        num_workers=0,  # keep loading in the main process
    )
    first_batch = next(iter(loader))

    assert isinstance(first_batch, dict)

    # Data matrix: present, a tensor, and batch_size x n_vars
    assert "x" in first_batch
    x = first_batch["x"]
    assert isinstance(x, torch.Tensor)
    assert x.shape == (4, 3)

    # Numeric obs column: present and converted to a tensor
    assert "group" in first_batch
    assert isinstance(first_batch["group"], torch.Tensor)

    # String obs column: only its presence is checked
    assert "category" in first_batch
def test_default_collate_fails_with_anncollection():
    """A plain ``DataLoader`` over an ``AnnCollection`` is expected to raise ``TypeError``.

    Kept as a sanity check that documents why the dedicated loader is needed.
    """
    from torch.utils.data import DataLoader

    from anndata.experimental.multi_files import AnnCollection

    parts = [
        ad.AnnData(X=np.random.rand(4, 3).astype(np.float32)) for _ in range(2)
    ]
    # The collection is handed to DataLoader directly as the dataset.
    vanilla_loader = DataLoader(AnnCollection(parts), batch_size=4, num_workers=2)

    with pytest.raises(TypeError):
        next(iter(vanilla_loader))