Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ test-min = [
test-array-api = [
{ include-group = "test-min" },
{ include-group = "test-jax" },
"mlx[cpu]",
]
test-jax = [
"jax>=0.6.0",
Expand Down
5 changes: 4 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2202,7 +2202,10 @@ def _check_2d_shape(X):

Assure that X is always 2D: Unlike numpy we always deal with 2D arrays.
"""
if X.dtype.names is None and len(X.shape) != 2:
if (
isinstance(X, np.ndarray)
or (isinstance(X, np.ndarray) and X.dtype.names is None)
) and len(X.shape) != 2:
msg = f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional."
raise ValueError(msg)

Expand Down
5 changes: 4 additions & 1 deletion src/anndata/_core/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,7 +1132,10 @@ def concat_pairwise_mapping(
elif all(isinstance(el, DaskArray) for el in els):
result[k] = _dask_block_diag(els)
else:
result[k] = sparse.block_diag(els, format="csr")
result[k] = sparse.block_diag(
[np.array(e) if not sparse.issparse(e) else e for e in els],
format="csr",
)
return result


Expand Down
5 changes: 3 additions & 2 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ def get_for_array(self, arr: SupportsArrayApi) -> SupportsArrayApi:
existing_xp = existing.__array_namespace__()
if existing_xp is xp:
return existing
return xp.from_dlpack(existing)
self.add_array(xp.from_dlpack(src_arr, copy=True))
# https://github.com/ml-explore/mlx/issues/48#issuecomment-3862013095 for asarray/dlpack
return getattr(xp, "from_dlpack", xp.asarray)(existing)
self.add_array(getattr(xp, "from_dlpack", xp.asarray)(src_arr, copy=True))
return self._manager[device]


Expand Down
15 changes: 15 additions & 0 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,21 @@ def jnp_array_or_idempotent(x):
return jnp.array(x)


if TYPE_CHECKING or find_spec("mlx"):
import mlx.core as mx

mx = _or_none(mx)
else:
mx = None


def mlx_array_or_idempotent(x):
if mx is None:
# In this case, the test should be marked by `array_api` to be conditionally skipped
return x
return mx.array(x)


try:
import fast_array_utils as _
except ImportError:
Expand Down
2 changes: 0 additions & 2 deletions src/anndata/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@ def __init__(self, adata: AnnData) -> None:

@runtime_checkable
class SupportsArrayApi(Protocol):
device: str
shape: tuple[int, ...]

def __array_namespace__(
self,
*,
api_version: Literal["2021.12", "2022.12", "2023.12", "2024.12"] | None = None,
) -> ModuleType: ...
def to_device(self, device: str, /, *, stream: int | Any | None = ...) -> Any: ...
def __dlpack__(
self,
*,
Expand Down
7 changes: 4 additions & 3 deletions tests/test_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
gen_vstr_recarray,
jnp,
jnp_array_or_idempotent,
mlx_array_or_idempotent,
)
from anndata.utils import asarray

Expand Down Expand Up @@ -1526,6 +1527,7 @@ def expected_shape(
[
pytest.param(np.array, id="np"),
pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
],
)
def test_concat_size_0_axis(
Expand Down Expand Up @@ -1577,9 +1579,8 @@ def test_concat_size_0_axis(
altaxis_new_inds = ~axis_labels(result, alt_axis).isin(axis_labels(a, alt_axis))
axis_idx = make_idx_tuple(axis_new_inds, axis)
altaxis_idx = make_idx_tuple(altaxis_new_inds, 1 - axis)

check_filled_like(result.X[axis_idx], elem_name="X")
check_filled_like(result.X[altaxis_idx], elem_name="X")
check_filled_like(result[axis_idx].X, elem_name="X")
check_filled_like(result[altaxis_idx].X, elem_name="X")
for k, elem in result.layers.items():
check_filled_like(elem[axis_idx], elem_name=f"layers/{k}")
check_filled_like(elem[altaxis_idx], elem_name=f"layers/{k}")
Expand Down
13 changes: 11 additions & 2 deletions tests/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
gen_adata,
jnp,
jnp_array_or_idempotent,
mlx_array_or_idempotent,
)

if TYPE_CHECKING:
Expand All @@ -48,6 +49,7 @@
pytest.param(csr_matrix, id="csr_matrix"),
pytest.param(csr_array, id="csr_array"),
pytest.param(jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
]

# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -292,6 +294,7 @@ def test_readwrite_backed(typ, backing_h5ad: Path) -> None:
pytest.param(
jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api
),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
],
)
def test_readwrite_equivalent_h5ad_zarr(tmp_path: Path, typ) -> None:
Expand Down Expand Up @@ -502,6 +505,7 @@ def test_read_tsv_iter():
pytest.param(
jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api
),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
np.array,
csr_matrix,
],
Expand All @@ -518,6 +522,7 @@ def test_write_csv(typ, tmp_path):
pytest.param(
jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api
),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
np.array,
csr_matrix,
],
Expand Down Expand Up @@ -562,10 +567,14 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]:
)
@pytest.mark.parametrize(
"xp_array",
[np.array, pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api)],
[
np.array,
pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
],
)
def test_readwrite_empty(read, write, name: str, tmp_path: Path, xp_array) -> None:
adata = ad.AnnData(uns=dict(empty=xp_array([]).astype(float)))
adata = ad.AnnData(uns=dict(empty=xp_array([])))
write(tmp_path / name, adata)
ad_read = read(tmp_path / name)
assert ad_read.uns["empty"] is not None
Expand Down
7 changes: 7 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
gen_adata,
jnp,
jnp_array_or_idempotent,
mlx_array_or_idempotent,
single_int_subset,
single_subset,
slice_int_subset,
Expand Down Expand Up @@ -1034,6 +1035,12 @@ def test_normalize_index_jax_boolean() -> None:
id="jax",
marks=pytest.mark.array_api,
),
pytest.param(
type(mlx_array_or_idempotent(np.array([1]))),
_from_array,
id="mlx",
marks=pytest.mark.array_api,
),
*(
[pytest.param(XDataArray, _from_xarray, id="xarray")]
if find_spec("xarray") is not None
Expand Down
2 changes: 2 additions & 0 deletions tests/test_x.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
gen_adata,
jnp,
jnp_array_or_idempotent,
mlx_array_or_idempotent,
)
from anndata.utils import asarray

Expand All @@ -28,6 +29,7 @@
pytest.param(sparse.csc_array, id="csc_array"),
pytest.param(asarray, id="ndarray"),
pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api),
pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api),
]
SINGULAR_SHAPES = [
pytest.param(shape, id=str(shape)) for shape in [(1, 10), (10, 1), (1, 1)]
Expand Down
Loading