From 8afd1816a324d38593f8bb970e58cadc8b416bc4 Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Sun, 8 Mar 2026 22:43:07 -0500 Subject: [PATCH 1/5] Replace use_cuda with device parameter in AnnLoader --- .../experimental/pytorch/_annloader.py | 47 +++++++++++++------ tests/test_annloader.py | 42 +++++++++++++++++ 2 files changed, 74 insertions(+), 15 deletions(-) create mode 100644 tests/test_annloader.py diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 2e3525fb1..525e35c81 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -11,6 +11,7 @@ from scipy.sparse import issparse from ..._core.anndata import AnnData +from ..._warnings import warn from ...compat import old_positionals from ..multi_files._anncollection import AnnCollection, _ConcatViewMixin @@ -27,6 +28,8 @@ type Array = torch.Tensor | np.ndarray | spmatrix +_UNSET = object() + # Custom sampler to get proper batches instead of joined separate indices # maybe move to multi_files @@ -70,21 +73,17 @@ def __len__(self) -> int: return length -# maybe replace use_cuda with explicit device option -def default_converter(arr: Array, *, use_cuda: bool, pin_memory: bool): +def default_converter(arr: Array, *, device: str = "cpu", pin_memory: bool): if isinstance(arr, torch.Tensor): - if use_cuda: - arr = arr.cuda() - elif pin_memory: + arr = arr.to(device) + if device == "cpu" and pin_memory: arr = arr.pin_memory() elif arr.dtype.name != "category" and np.issubdtype(arr.dtype, np.number): if issparse(arr): arr = arr.toarray() - if use_cuda: - arr = torch.tensor(arr, device="cuda") - else: - arr = torch.tensor(arr) - arr = arr.pin_memory() if pin_memory else arr + arr = torch.tensor(arr, device=device) + if device == "cpu" and pin_memory: + arr = arr.pin_memory() return arr @@ -135,12 +134,15 @@ class AnnLoader(DataLoader): Set to `True` to have the data reshuffled at every epoch. use_default_converter Use the default converter to convert arrays to pytorch tensors, transfer to - the default cuda device (if `use_cuda=True`), do memory pinning (if `pin_memory=True`). + the specified device (if `device` is set), do memory pinning (if `pin_memory=True`). If you pass an AnnCollection object with prespecified converters, the default converter won't overwrite these converters but will be applied on top of them. + device + Transfer pytorch tensors to the specified device after conversion. + Only works if `use_default_converter=True`. use_cuda - Transfer pytorch tensors to the default cuda device after conversion. - Only works if `use_default_converter=True` + .. deprecated:: + Use `device='cuda'` instead. **kwargs Arguments for PyTorch DataLoader. If `adatas` is not an `AnnCollection` object, then also arguments for `AnnCollection` initialization. @@ -154,9 +156,24 @@ def __init__( batch_size: int = 1, shuffle: bool = False, use_default_converter: bool = True, - use_cuda: bool = False, + device: str = "cpu", + use_cuda: bool = _UNSET, **kwargs, ): + if use_cuda is not _UNSET: + if device != "cpu": + msg = ( + "Cannot specify both 'device' and 'use_cuda'. " + "Use 'device' instead." + ) + raise ValueError(msg) + warn( + "'use_cuda' is deprecated, use 'device' instead. " + "Pass device='cuda' instead of use_cuda=True.", + FutureWarning, + ) + device = "cuda" if use_cuda else "cpu" + if isinstance(adatas, AnnData): adatas = [adatas] @@ -191,7 +208,7 @@ def __init__( if use_default_converter: pin_memory = kwargs.pop("pin_memory", False) _converter = partial( - default_converter, use_cuda=use_cuda, pin_memory=pin_memory + default_converter, device=device, pin_memory=pin_memory ) dataset.convert = _convert_on_top( dataset.convert, _converter, dict(dataset.attrs_keys, X=[]) diff --git a/tests/test_annloader.py b/tests/test_annloader.py new file mode 100644 index 000000000..8ab3d29da --- /dev/null +++ b/tests/test_annloader.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import numpy as np +import pytest + +import anndata as ad + +pytest.importorskip("torch") + +from anndata.experimental.pytorch import AnnLoader + + +@pytest.fixture +def adata(): + return ad.AnnData(X=np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) + + +def test_annloader_default_device(adata): + """AnnLoader with default device='cpu' produces CPU tensors.""" + loader = AnnLoader(adata, batch_size=2) + batch = next(iter(loader)) + assert batch.X.device.type == "cpu" + + +def test_annloader_explicit_cpu_device(adata): + """AnnLoader with explicit device='cpu' produces CPU tensors.""" + loader = AnnLoader(adata, batch_size=2, device="cpu") + batch = next(iter(loader)) + assert batch.X.device.type == "cpu" + + +def test_annloader_use_cuda_deprecation_warning(adata): + """Passing use_cuda emits a FutureWarning.""" + with pytest.warns(FutureWarning, match="use_cuda.*deprecated"): + # use_cuda=False should still emit warning (parameter was explicitly passed) + AnnLoader(adata, batch_size=2, use_cuda=False) + + +def test_annloader_use_cuda_and_device_conflict(adata): + """Passing both use_cuda and device raises ValueError.""" + with pytest.raises(ValueError, match="Cannot specify both"): + AnnLoader(adata, batch_size=2, use_cuda=True, device="cuda") From 8447a5f754f79e8dbb197e500d199a25c335b9d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Mar 2026 03:43:48 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/experimental/pytorch/_annloader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 525e35c81..55e132326 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -163,8 +163,7 @@ def __init__( if use_cuda is not _UNSET: if device != "cpu": msg = ( - "Cannot specify both 'device' and 'use_cuda'. " - "Use 'device' instead." + "Cannot specify both 'device' and 'use_cuda'. Use 'device' instead." ) raise ValueError(msg) warn( From 6f01e7490a2c0b4df0c62af0e76965b18c982a8d Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Tue, 10 Mar 2026 14:46:16 -0500 Subject: [PATCH 3/5] Use Empty.TOKEN sentinel and Literal type for device parameter --- src/anndata/experimental/pytorch/_annloader.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 55e132326..22a2f5ea6 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -5,14 +5,14 @@ from functools import partial from importlib.util import find_spec from math import ceil -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np from scipy.sparse import issparse from ..._core.anndata import AnnData from ..._warnings import warn -from ...compat import old_positionals +from ...compat import Empty, old_positionals from ..multi_files._anncollection import AnnCollection, _ConcatViewMixin if find_spec("torch") or TYPE_CHECKING: @@ -28,8 +28,6 @@ type Array = torch.Tensor | np.ndarray | spmatrix -_UNSET = object() - # Custom sampler to get proper batches instead of joined separate indices # maybe move to multi_files @@ -73,7 +71,7 @@ def __len__(self) -> int: return length -def default_converter(arr: Array, *, device: str = "cpu", pin_memory: bool): +def default_converter(arr: Array, *, device: Literal["cpu", "cuda", "mps"] = "cpu", pin_memory: bool): if isinstance(arr, torch.Tensor): arr = arr.to(device) if device == "cpu" and pin_memory: @@ -156,11 +154,11 @@ def __init__( batch_size: int = 1, shuffle: bool = False, use_default_converter: bool = True, - device: str = "cpu", - use_cuda: bool = _UNSET, + device: Literal["cpu", "cuda", "mps"] = "cpu", + use_cuda: bool = Empty.TOKEN, **kwargs, ): - if use_cuda is not _UNSET: + if use_cuda is not Empty.TOKEN: if device != "cpu": msg = ( "Cannot specify both 'device' and 'use_cuda'. Use 'device' instead." From 57b35155af18cd2e9abbe64fa0bb48dc84bb01ab Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 19:46:47 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anndata/experimental/pytorch/_annloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 22a2f5ea6..26bc11b1a 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -71,7 +71,9 @@ def __len__(self) -> int: return length -def default_converter(arr: Array, *, device: Literal["cpu", "cuda", "mps"] = "cpu", pin_memory: bool): +def default_converter( + arr: Array, *, device: Literal["cpu", "cuda", "mps"] = "cpu", pin_memory: bool +): if isinstance(arr, torch.Tensor): arr = arr.to(device) if device == "cpu" and pin_memory: From a520fbc1a713e46a5f37f90842791693bbd20ab8 Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Tue, 10 Mar 2026 15:01:54 -0500 Subject: [PATCH 5/5] Move Literal import into TYPE_CHECKING block --- src/anndata/experimental/pytorch/_annloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anndata/experimental/pytorch/_annloader.py b/src/anndata/experimental/pytorch/_annloader.py index 26bc11b1a..ae37fd8ca 100644 --- a/src/anndata/experimental/pytorch/_annloader.py +++ b/src/anndata/experimental/pytorch/_annloader.py @@ -5,7 +5,7 @@ from functools import partial from importlib.util import find_spec from math import ceil -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING import numpy as np from scipy.sparse import issparse @@ -23,6 +23,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Generator, Sequence + from typing import Literal from scipy.sparse import spmatrix