Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def res(
"numpy.dtypes.StringDType": ("py:attr", "numpy.dtypes.StringDType"),
"pandas.DataFrame.iloc": ("py:attr", "pandas.DataFrame.iloc"),
"pandas.DataFrame.loc": ("py:attr", "pandas.DataFrame.loc"),
"pandas.core.dtypes.dtypes.BaseMaskedDtype": "pandas.api.extensions.ExtensionDtype",
# should be fixed soon: https://github.com/tox-dev/sphinx-autodoc-typehints/pull/516
"types.EllipsisType": ("py:data", "types.EllipsisType"),
"pathlib._local.Path": "pathlib.Path",
Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/2287.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix {obj}`numpy.uint` support in {func}`anndata.experimental.read_lazy` and {func}`anndata.experimental.read_elem_lazy` {user}`flying-sheep`
31 changes: 30 additions & 1 deletion src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
from __future__ import annotations

from collections.abc import Callable
from functools import WRAPPER_ASSIGNMENTS, wraps
from functools import WRAPPER_ASSIGNMENTS, cache, wraps
from itertools import pairwise
from typing import TYPE_CHECKING, Literal, cast
from warnings import warn

import numpy as np
import pandas as pd

from .._core.sparse_dataset import BaseCompressedSparseDataset

if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from typing import Any, Literal

from pandas.core.dtypes.dtypes import BaseMaskedDtype

from .._types import StorageType, _WriteInternal
from ..compat import H5Group, ZarrGroup
from ..typing import RWAble
Expand Down Expand Up @@ -119,6 +124,30 @@ def check_key(key):
raise TypeError(msg)


@cache
def pandas_nullable_dtype(dtype: np.dtype) -> BaseMaskedDtype:
"""Infer nullable dtype from numpy dtype.

There is no public pandas API for this, so this is the cleanest way.
See <https://github.com/pandas-dev/pandas/issues/63608>
"""
try:
from pandas.core.dtypes.dtypes import BaseMaskedDtype
except ImportError:
pass
else:
return BaseMaskedDtype.from_numpy_dtype(dtype)

match dtype.kind:
case "b":
array_type = pd.arrays.BooleanArray
case "i" | "u":
array_type = pd.arrays.IntegerArray
case _:
raise NotImplementedError
return array_type(np.ones(1, dtype), np.ones(1, bool)).dtype


# -------------------------------------------------------------------------------
# Generic functions
# -------------------------------------------------------------------------------
Expand Down
45 changes: 20 additions & 25 deletions src/anndata/experimental/backed/_lazy_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anndata._core.views import as_view
from anndata._io.specs.lazy_methods import get_chunksize

from ..._io.utils import pandas_nullable_dtype
from ..._settings import settings
from ...compat import (
NULLABLE_NUMPY_STRING_TYPE,
Expand All @@ -26,8 +27,9 @@
from pathlib import Path
from typing import Literal

from numpy.typing import NDArray
from pandas._libs.missing import NAType
from pandas.core.dtypes.base import ExtensionDtype
from pandas.core.dtypes.dtypes import BaseMaskedDtype

from anndata.compat import ZarrGroup

Expand Down Expand Up @@ -172,40 +174,33 @@ def __init__(
self.file_format = "zarr" if isinstance(mask, ZarrArray) else "h5"
self.elem_name = elem_name

def __getitem__(self, key: ExplicitIndexer) -> PandasExtensionArray | np.ndarray:
from xarray.core.extension_array import PandasExtensionArray

def __getitem__(
self, key: ExplicitIndexer
) -> PandasExtensionArray | NDArray[np.str_]:
values = self._values[key]
mask = self._mask[key]
if self._dtype_str == "nullable-integer":
# numpy does not support nan ints
extension_array = pd.arrays.IntegerArray(values, mask=mask)
elif self._dtype_str == "nullable-boolean":
extension_array = pd.arrays.BooleanArray(values, mask=mask)
elif self._dtype_str == "nullable-string-array":

if isinstance(self.dtype, np.dtypes.StringDType):
# https://github.com/pydata/xarray/issues/10419
values = values.astype(self.dtype)
values[mask] = pd.NA
return values
else:
msg = f"Invalid dtype_str {self._dtype_str}"
raise RuntimeError(msg)
return PandasExtensionArray(extension_array)

from xarray.core.extension_array import PandasExtensionArray

cls = self.dtype.construct_array_type()
return PandasExtensionArray(cls(values, mask))

@cached_property
def dtype(self) -> np.dtypes.StringDType[NAType] | ExtensionDtype:
if self._dtype_str == "nullable-integer":
return pd.array(
[],
dtype=str(pd.api.types.pandas_dtype(self._values.dtype)).capitalize(),
).dtype
elif self._dtype_str == "nullable-boolean":
return pd.BooleanDtype()
elif self._dtype_str == "nullable-string-array":
def dtype(self) -> BaseMaskedDtype | np.dtypes.StringDType[NAType]:
if self._dtype_str == "nullable-string-array":
# https://github.com/pydata/xarray/issues/10419
return NULLABLE_NUMPY_STRING_TYPE
msg = f"Invalid dtype_str {self._dtype_str}"
raise RuntimeError(msg)
try:
return pandas_nullable_dtype(self._values.dtype)
except NotImplementedError:
msg = f"Invalid dtype_str {self._dtype_str}"
raise RuntimeError(msg) from None


@_subset.register(XDataArray)
Expand Down
32 changes: 14 additions & 18 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@
DEFAULT_COL_TYPES = (
pd.CategoricalDtype(ordered=False),
pd.CategoricalDtype(ordered=True),
np.int64,
np.float64,
np.uint8,
np.bool_,
pd.BooleanDtype,
pd.Int32Dtype,
np.dtype(np.int64),
np.dtype(np.float64),
np.dtype(np.uint8),
np.dtype(bool),
pd.BooleanDtype(),
pd.Int32Dtype(),
pd.UInt8Dtype(),
)


Expand Down Expand Up @@ -108,13 +109,11 @@ def gen_vstr_recarray(m, n, dtype=None):


def issubdtype(
a: np.dtype | pd.api.extensions.ExtensionDtype | type,
b: type[DT] | tuple[type[DT], ...],
a: np.dtype | pd.api.extensions.ExtensionDtype, b: type[DT] | tuple[type[DT], ...]
) -> TypeGuard[DT]:
assert not isinstance(a, type)
if isinstance(b, tuple):
return any(issubdtype(a, t) for t in b)
if isinstance(a, type) and issubclass(a, pd.api.extensions.ExtensionDtype):
return issubclass(a, b)
if isinstance(a, pd.api.extensions.ExtensionDtype):
return isinstance(a, b)
try:
Expand All @@ -126,6 +125,7 @@ def issubdtype(
def gen_random_column( # noqa: PLR0911
n: int, dtype: np.dtype | pd.api.extensions.ExtensionDtype
) -> tuple[str, np.ndarray | pd.api.extensions.ExtensionArray]:
assert isinstance(dtype, np.dtype | pd.api.extensions.ExtensionDtype)
if issubdtype(dtype, pd.CategoricalDtype):
# TODO: Think about allowing index to be passed for n
letters = np.fromiter(iter(ascii_letters), "U1")
Expand All @@ -142,13 +142,9 @@ def gen_random_column( # noqa: PLR0911
),
)
if issubdtype(dtype, IntegerDtype):
return (
"nullable-int",
pd.arrays.IntegerArray(
np.random.randint(0, 1000, size=n, dtype=np.int32),
mask=np.random.randint(0, 2, size=n, dtype=bool),
),
)
name, values = gen_random_column(n, dtype.numpy_dtype)
mask = np.random.randint(0, 2, size=n, dtype=bool)
return f"nullable-{name}", pd.arrays.IntegerArray(values, mask)
if issubdtype(dtype, pd.StringDtype):
letters = np.fromiter(iter(ascii_letters), "U1")
array = pd.array(np.random.choice(letters, n), dtype=pd.StringDtype())
Expand All @@ -162,7 +158,7 @@ def gen_random_column( # noqa: PLR0911
if not issubdtype(dtype, np.number): # pragma: no cover
pytest.fail(f"Unexpected dtype: {dtype}")

n_bits = 8 * (dtype().itemsize if isinstance(dtype, type) else dtype.itemsize)
n_bits = 8 * dtype.itemsize

if issubdtype(dtype, np.unsignedinteger):
return f"uint{n_bits}", np.random.randint(0, 255, n, dtype=dtype)
Expand Down
4 changes: 2 additions & 2 deletions tests/lazy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def adata_remote_orig_with_path(
orig = gen_adata(
(100, 110),
mtx_format,
obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype),
var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype),
obs_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype()),
var_dtypes=(*DEFAULT_COL_TYPES, pd.StringDtype()),
obsm_types=(*DEFAULT_KEY_TYPES, AwkArray),
varm_types=(*DEFAULT_KEY_TYPES, AwkArray),
)
Expand Down
18 changes: 11 additions & 7 deletions tests/lazy/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,21 @@ def test_concat_to_memory_obs(


def test_concat_to_memory_obs_dtypes(
subtests: pytest.Subtests,
lazy_adatas_for_concat: list[AnnData],
join: Join_T,
):
) -> None:
concated_remote = ad.concat(lazy_adatas_for_concat, join=join)
# check preservation of non-categorical dtypes on the concat axis
assert concated_remote.obs["int64"].dtype == "int64"
assert concated_remote.obs["uint8"].dtype == "uint8"
assert concated_remote.obs["nullable-int"].dtype == "int32"
assert concated_remote.obs["float64"].dtype == "float64"
assert concated_remote.obs["bool"].dtype == "bool"
assert concated_remote.obs["nullable-bool"].dtype == "bool"
for name in concated_remote.obs.columns:
dtype = name.removeprefix("nullable-")
with subtests.test(col=name):
try:
assert concated_remote.obs[name].dtype == dtype
except AssertionError:
if "cat" in name:
pytest.xfail("categorical dtypes are not preserved")
raise


def test_concat_to_memory_var(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1509,7 +1509,7 @@ def test_concat_size_0_axis(
"""Regression test for https://github.com/scverse/anndata/issues/526"""
axis, axis_name = merge._resolve_axis(axis_name)
alt_axis = 1 - axis
col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype)
col_dtypes = (*DEFAULT_COL_TYPES, pd.StringDtype())
a = gen_adata(
(5, 7),
obs_dtypes=col_dtypes,
Expand Down
7 changes: 6 additions & 1 deletion tests/test_dask_view_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ def attr_name(request):
return request.param


@pytest.fixture(params=[True, False])
@pytest.fixture(
params=[
pytest.param(True, id="give_chunks"),
pytest.param(False, id="no_give_chunks"),
]
)
def give_chunks(request):
return request.param

Expand Down
12 changes: 10 additions & 2 deletions tests/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,15 @@ def test_gen_awkward(shape, datashape):
assert arr.type == arr_type


@pytest.mark.parametrize("dtype", [*DEFAULT_COL_TYPES, pd.StringDtype])
@pytest.mark.parametrize(
"dtype",
[*DEFAULT_COL_TYPES, pd.StringDtype()],
ids=lambda dt: (
f"{dt}-{'' if dt.ordered else 'un'}ordered"
if isinstance(dt, pd.CategoricalDtype)
else str(dt)
),
)
def test_gen_random_column(dtype):
_, col = gen_random_column(10, dtype)
assert len(col) == 10
Expand All @@ -96,7 +104,7 @@ def test_gen_random_column(dtype):
assert issubdtype(col.dtype, pd.CategoricalDtype)
assert col.dtype.ordered == dtype.ordered
else:
assert issubdtype(col.dtype, dtype)
assert col.dtype == dtype


# Does this work for every warning?
Expand Down
2 changes: 1 addition & 1 deletion tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_view_subset_shapes():
adata = gen_adata((20, 10), **GEN_ADATA_DASK_ARGS)

view = adata[:, ::2]
assert view.var.shape == (5, 8)
assert view.var.shape == (5, adata.var.shape[1])
assert {k: v.shape[0] for k, v in view.varm.items()} == dict.fromkeys(view.varm, 5)


Expand Down
Loading