diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 2e3525fb1..ae37fd8ca 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -11,7 +11,8 @@ from scipy.sparse import issparse from ..._core.anndata import AnnData -from ...compat import old_positionals +from ..._warnings import warn +from ...compat import Empty, old_positionals from ..multi_files._anncollection import AnnCollection, _ConcatViewMixin if find_spec("torch") or TYPE_CHECKING: @@ -22,6 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence + from typing import Literal from scipy.sparse import spmatrix @@ -70,21 +72,19 @@ def __len__(self) -> int: return length -# maybe replace use_cuda with explicit device option -def default_converter(arr: Array, *, use_cuda: bool, pin_memory: bool): +def default_converter( + arr: Array, *, device: Literal["cpu", "cuda", "mps"] = "cpu", pin_memory: bool +): if isinstance(arr, torch.Tensor): - if use_cuda: - arr = arr.cuda() - elif pin_memory: + arr = arr.to(device) + if device == "cpu" and pin_memory: arr = arr.pin_memory() elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number): if issparse(arr): arr = arr.toarray() - if use_cuda: - arr = torch.tensor(arr, device="cuda") - else: - arr = torch.tensor(arr) - arr = arr.pin_memory() if pin_memory else arr + arr = torch.tensor(arr, device=device) + if device == "cpu" and pin_memory: + arr = arr.pin_memory() return arr @@ -135,12 +135,15 @@ class AnnLoader(DataLoader): Set to `True` to have the data reshuffled at every epoch. use_default_converter Use the default converter to convert arrays to pytorch tensors, transfer to - the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`). + the specified device (if `device` is set), do memory pinning (if `pin_memory=True`). If you pass an AnnCollection object with prespecified converters, the default converter won't overwrite these converters but will be applied on top of them. + device + Transfer pytorch tensors to the specified device after conversion. + Only works if `use_default_converter=True`. use_cuda - Transfer pytorch tensors to the default cuda device after conversion. - Only works if `use_default_converter=True` + .. deprecated:: + Use `device='cuda'` instead. **kwargs Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also arguments for `AnnCollection` initialization. @@ -154,9 +157,23 @@ def __init__( batch_size: int = 1, shuffle: bool = False, use_default_converter: bool = True, - use_cuda: bool = False, + device: Literal["cpu", "cuda", "mps"] = "cpu", + use_cuda: bool = Empty.TOKEN, **kwargs, ): + if use_cuda is not Empty.TOKEN: + if device != "cpu": + msg = ( + "Cannot specify both 'device' and 'use_cuda'. Use 'device' instead." + ) + raise ValueError(msg) + warn( + "'use_cuda' is deprecated, use 'device' instead. " + "Pass device='cuda' instead of use_cuda=True.", + FutureWarning, + ) + device = "cuda" if use_cuda else "cpu" + if isinstance(adatas, AnnData): adatas = [adatas] @@ -191,7 +208,7 @@ def __init__( if use_default_converter: pin_memory = kwargs.pop("pin_memory", False) _converter = partial( - default_converter, use_cuda=use_cuda, pin_memory=pin_memory + default_converter, device=device, pin_memory=pin_memory ) dataset.convert = _convert_on_top( dataset.convert, _converter, dict(dataset.attrs_keys, X=[]) diff --git a/tests/test_annloader.py b/tests/test_annloader.py new file mode 100644 index 000000000..8ab3d29da --- /dev/null +++ b/tests/test_annloader.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import anndata as ad + +pytest.importorskip("torch") + +from anndata.experimental.pytorch import AnnLoader + + +@pytest.fixture +def adata(): + return ad.AnnData(X=np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + + +def test_annloader_default_device(adata): + """AnnLoader with default device='cpu' produces CPU tensors.""" + loader = AnnLoader(adata, batch_size=2) + batch = next(iter(loader)) + assert batch.X.device.type == "cpu" + + +def test_annloader_explicit_cpu_device(adata): + """AnnLoader with explicit device='cpu' produces CPU tensors.""" + loader = AnnLoader(adata, batch_size=2, device="cpu") + batch = next(iter(loader)) + assert batch.X.device.type == "cpu" + + +def test_annloader_use_cuda_deprecation_warning(adata): + """Passing use_cuda emits a FutureWarning.""" + with pytest.warns(FutureWarning, match="use_cuda.*deprecated"): + # use_cuda=False should still emit warning (parameter was explicitly passed) + AnnLoader(adata, batch_size=2, use_cuda=False) + + +def test_annloader_use_cuda_and_device_conflict(adata): + """Passing both use_cuda and device raises ValueError.""" + with pytest.raises(ValueError, match="Cannot specify both"): + AnnLoader(adata, batch_size=2, use_cuda=True, device="cuda")