diff --git a/docs/conf.py b/docs/conf.py index a14cc70ea..bbbf55f87 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -173,6 +173,7 @@ def res( "numpy.dtypes.StringDType": ("py:attr", "numpy.dtypes.StringDType"), "pandas.DataFrame.iloc": ("py:attr", "pandas.DataFrame.iloc"), "pandas.DataFrame.loc": ("py:attr", "pandas.DataFrame.loc"), + "pandas.core.dtypes.dtypes.BaseMaskedDtype": "pandas.api.extensions.ExtensionDtype", # should be fixed soon: https://github.com/tox-dev/sphinx-autodoc-typehints/pull/516 "types.EllipsisType": ("py:data", "types.EllipsisType"), "pathlib._local.Path": "pathlib.Path", diff --git a/docs/release-notes/2287.fix.md b/docs/release-notes/2287.fix.md new file mode 100644 index 000000000..0c3a3299e --- /dev/null +++ b/docs/release-notes/2287.fix.md @@ -0,0 +1 @@ +Fix {obj}`numpy.uint` support in {func}`anndata.experimental.read_lazy` and {func}`anndata.experimental.read_elem_lazy` {user}`flying-sheep` diff --git a/src/anndata/_io/utils.py b/src/anndata/_io/utils.py index 223412fdd..b3cc6a56d 100644 --- a/src/anndata/_io/utils.py +++ b/src/anndata/_io/utils.py @@ -1,17 +1,22 @@ from __future__ import annotations from collections.abc import Callable -from functools import WRAPPER_ASSIGNMENTS, wraps +from functools import WRAPPER_ASSIGNMENTS, cache, wraps from itertools import pairwise from typing import TYPE_CHECKING, Literal, cast from warnings import warn +import numpy as np +import pandas as pd + from .._core.sparse_dataset import BaseCompressedSparseDataset if TYPE_CHECKING: from collections.abc import Callable, Mapping from typing import Any, Literal + from pandas.core.dtypes.dtypes import BaseMaskedDtype + from .._types import StorageType, _WriteInternal from ..compat import H5Group, ZarrGroup from ..typing import RWAble @@ -119,6 +124,30 @@ def check_key(key): raise TypeError(msg) +@cache +def pandas_nullable_dtype(dtype: np.dtype) -> BaseMaskedDtype: + """Infer nullable dtype from numpy dtype. + + There is no public pandas API for this, so this is the cleanest way. + See + """ + try: + from pandas.core.dtypes.dtypes import BaseMaskedDtype + except ImportError: + pass + else: + return BaseMaskedDtype.from_numpy_dtype(dtype) + + match dtype.kind: + case "b": + array_type = pd.arrays.BooleanArray + case "i" | "u": + array_type = pd.arrays.IntegerArray + case _: + raise NotImplementedError + return array_type(np.ones(1, dtype), np.ones(1, bool)).dtype + + # ------------------------------------------------------------------------------- # Generic functions # ------------------------------------------------------------------------------- diff --git a/src/anndata/experimental/backed/_lazy_arrays.py b/src/anndata/experimental/backed/_lazy_arrays.py index 0b87de381..f073473e0 100644 --- a/src/anndata/experimental/backed/_lazy_arrays.py +++ b/src/anndata/experimental/backed/_lazy_arrays.py @@ -10,6 +10,7 @@ from anndata._core.views import as_view from anndata._io.specs.lazy_methods import get_chunksize +from ..._io.utils import pandas_nullable_dtype from ..._settings import settings from ...compat import ( NULLABLE_NUMPY_STRING_TYPE, @@ -26,8 +27,9 @@ from pathlib import Path from typing import Literal + from numpy.typing import NDArray from pandas._libs.missing import NAType - from pandas.core.dtypes.base import ExtensionDtype + from pandas.core.dtypes.dtypes import BaseMaskedDtype from anndata.compat import ZarrGroup @@ -172,40 +174,33 @@ def __init__( self.file_format = "zarr" if isinstance(mask, ZarrArray) else "h5" self.elem_name = elem_name - def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray | np.ndarray: - from xarray.core.extension_array import PandasExtensionArray - + def __getitem__( + self, key: ExplicitIndexer + ) -> PandasExtensionArray | NDArray[np.str_]: values = self._values[key] mask = self._mask[key] - if self._dtype_str == "nullable-integer": - # numpy does not support nan ints - extension_array = pd.arrays.IntegerArray(values, mask=mask) - elif self._dtype_str == "nullable-boolean": - extension_array = pd.arrays.BooleanArray(values, mask=mask) - elif self._dtype_str == "nullable-string-array": + + if isinstance(self.dtype, np.dtypes.StringDType): # https://github.com/pydata/xarray/issues/10419 values = values.astype(self.dtype) values[mask] = pd.NA return values - else: - msg = f"Invalid dtype_str {self._dtype_str}" - raise RuntimeError(msg) - return PandasExtensionArray(extension_array) + + from xarray.core.extension_array import PandasExtensionArray + + cls = self.dtype.construct_array_type() + return PandasExtensionArray(cls(values, mask)) @cached_property - def dtype(self) -> np.dtypes.StringDType[NAType] | ExtensionDtype: - if self._dtype_str == "nullable-integer": - return pd.array( - [], - dtype=str(pd.api.types.pandas_dtype(self._values.dtype)).capitalize(), - ).dtype - elif self._dtype_str == "nullable-boolean": - return pd.BooleanDtype() - elif self._dtype_str == "nullable-string-array": + def dtype(self) -> BaseMaskedDtype | np.dtypes.StringDType[NAType]: + if self._dtype_str == "nullable-string-array": # https://github.com/pydata/xarray/issues/10419 return NULLABLE_NUMPY_STRING_TYPE - msg = f"Invalid dtype_str {self._dtype_str}" - raise RuntimeError(msg) + try: + return pandas_nullable_dtype(self._values.dtype) + except NotImplementedError: + msg = f"Invalid dtype_str {self._dtype_str}" + raise RuntimeError(msg) from None @_subset.register(XDataArray) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 7cf7011be..bfaf8e0d7 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -75,12 +75,13 @@ DEFAULT_COL_TYPES = ( pd.CategoricalDtype(ordered=False), pd.CategoricalDtype(ordered=True), - np.int64, - np.float64, - np.uint8, - np.bool_, - pd.BooleanDtype, - pd.Int32Dtype, + np.dtype(np.int64), + np.dtype(np.float64), + np.dtype(np.uint8), + np.dtype(bool), + pd.BooleanDtype(), + pd.Int32Dtype(), + pd.UInt8Dtype(), ) @@ -108,13 +109,11 @@ def gen_vstr_recarray(m, n, dtype=None): def issubdtype( - a: np.dtype | pd.api.extensions.ExtensionDtype | type, - b: type[DT] | tuple[type[DT], ...], + a: np.dtype | pd.api.extensions.ExtensionDtype, b: type[DT] | tuple[type[DT], ...] ) -> TypeGuard[DT]: + assert not isinstance(a, type) if isinstance(b, tuple): return any(issubdtype(a, t) for t in b) - if isinstance(a, type) and issubclass(a, pd.api.extensions.ExtensionDtype): - return issubclass(a, b) if isinstance(a, pd.api.extensions.ExtensionDtype): return isinstance(a, b) try: @@ -126,6 +125,7 @@ def issubdtype( def gen_random_column( # noqa: PLR0911 n: int, dtype: np.dtype | pd.api.extensions.ExtensionDtype ) -> tuple[str, np.ndarray | pd.api.extensions.ExtensionArray]: + assert isinstance(dtype, np.dtype | pd.api.extensions.ExtensionDtype) if issubdtype(dtype, pd.CategoricalDtype): # TODO: Think about allowing index to be passed for n letters = np.fromiter(iter(ascii_letters), "U1") @@ -142,13 +142,9 @@ def gen_random_column( # noqa: PLR0911 ), ) if issubdtype(dtype, IntegerDtype): - return ( - "nullable-int", - pd.arrays.IntegerArray( - np.random.randint(0, 1000, size=n, dtype=np.int32), - mask=np.random.randint(0, 2, size=n, dtype=bool), - ), - ) + name, values = gen_random_column(n, dtype.numpy_dtype) + mask = np.random.randint(0, 2, size=n, dtype=bool) + return f"nullable-{name}", pd.arrays.IntegerArray(values, mask) if issubdtype(dtype, pd.StringDtype): letters = np.fromiter(iter(ascii_letters), "U1") array = pd.array(np.random.choice(letters, n), dtype=pd.StringDtype()) @@ -162,7 +158,7 @@ def gen_random_column( # noqa: PLR0911 if not issubdtype(dtype, np.number): # pragma: no cover pytest.fail(f"Unexpected dtype: {dtype}") - n_bits = 8 * (dtype().itemsize if isinstance(dtype, type) else dtype.itemsize) + n_bits = 8 * dtype.itemsize if issubdtype(dtype, np.unsignedinteger): return f"uint{n_bits}", np.random.randint(0, 255, n, dtype=dtype) diff --git a/tests/lazy/conftest.py b/tests/lazy/conftest.py index 50d0f84c3..b71ac75e6 100644 --- a/tests/lazy/conftest.py +++ b/tests/lazy/conftest.py @@ -93,8 +93,8 @@ def adata_remote_orig_with_path( orig = gen_adata( (100, 110), mtx_format, - obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype), - var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype), + obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype()), + var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype()), obsm_types=(*DEFAULT_KEY_TYPES, AwkArray), varm_types=(*DEFAULT_KEY_TYPES, AwkArray), ) diff --git a/tests/lazy/test_concat.py b/tests/lazy/test_concat.py index c9b25bcb2..fe256fdfc 100644 --- a/tests/lazy/test_concat.py +++ b/tests/lazy/test_concat.py @@ -147,17 +147,21 @@ def test_concat_to_memory_obs( def test_concat_to_memory_obs_dtypes( + subtests: pytest.Subtests, lazy_adatas_for_concat: list[AnnData], join: Join_T, -): +) -> None: concated_remote = ad.concat(lazy_adatas_for_concat, join=join) # check preservation of non-categorical dtypes on the concat axis - assert concated_remote.obs["int64"].dtype == "int64" - assert concated_remote.obs["uint8"].dtype == "uint8" - assert concated_remote.obs["nullable-int"].dtype == "int32" - assert concated_remote.obs["float64"].dtype == "float64" - assert concated_remote.obs["bool"].dtype == "bool" - assert concated_remote.obs["nullable-bool"].dtype == "bool" + for name in concated_remote.obs.columns: + dtype = name.removeprefix("nullable-") + with subtests.test(col=name): + try: + assert concated_remote.obs[name].dtype == dtype + except AssertionError: + if "cat" in name: + pytest.xfail("categorical dtypes are not preserved") + raise def test_concat_to_memory_var( diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index cb1a2b2b9..8572c2a13 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -1509,7 +1509,7 @@ def test_concat_size_0_axis( """Regression test for https://github.com/scverse/anndata/issues/526""" axis, axis_name = merge._resolve_axis(axis_name) alt_axis = 1 - axis - col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype) + col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype()) a = gen_adata( (5, 7), obs_dtypes=col_dtypes, diff --git a/tests/test_dask_view_mem.py b/tests/test_dask_view_mem.py index b7fe7c72c..0b2720397 100644 --- a/tests/test_dask_view_mem.py +++ b/tests/test_dask_view_mem.py @@ -26,7 +26,12 @@ def attr_name(request): return request.param -@pytest.fixture(params=[True, False]) +@pytest.fixture( + params=[ + pytest.param(True, id="give_chunks"), + pytest.param(False, id="no_give_chunks"), + ] +) def give_chunks(request): return request.param diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5fa26f16f..da9e923e8 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -87,7 +87,15 @@ def test_gen_awkward(shape, datashape): assert arr.type == arr_type -@pytest.mark.parametrize("dtype", [*DEFAULT_COL_TYPES, pd.StringDtype]) +@pytest.mark.parametrize( + "dtype", + [*DEFAULT_COL_TYPES, pd.StringDtype()], + ids=lambda dt: ( + f"{dt}-{'' if dt.ordered else 'un'}ordered" + if isinstance(dt, pd.CategoricalDtype) + else str(dt) + ), +) def test_gen_random_column(dtype): _, col = gen_random_column(10, dtype) assert len(col) == 10 @@ -96,7 +104,7 @@ def test_gen_random_column(dtype): assert issubdtype(col.dtype, pd.CategoricalDtype) assert col.dtype.ordered == dtype.ordered else: - assert issubdtype(col.dtype, dtype) + assert col.dtype == dtype # Does this work for every warning? diff --git a/tests/test_views.py b/tests/test_views.py index 2f206f921..83bc58748 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -158,7 +158,7 @@ def test_view_subset_shapes(): adata = gen_adata((20, 10), **GEN_ADATA_DASK_ARGS) view = adata[:, ::2] - assert view.var.shape == (5, 8) + assert view.var.shape == (5, adata.var.shape[1]) assert {k: v.shape[0] for k, v in view.varm.items()} == dict.fromkeys(view.varm, 5)