From 1e705d87555b64ae8e614d95ee314b812b0d1dfc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 6 Feb 2026 19:41:40 +0100 Subject: [PATCH 1/4] feat: `mlx` compat --- pyproject.toml | 1 + src/anndata/_core/anndata.py | 5 ++++- src/anndata/_core/merge.py | 5 ++++- src/anndata/tests/helpers.py | 15 +++++++++++++++ src/anndata/types.py | 2 -- tests/test_concatenate.py | 7 ++++--- tests/test_readwrite.py | 11 ++++++++++- tests/test_views.py | 6 ++++++ tests/test_x.py | 2 ++ 9 files changed, 46 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 666cba289..11504a041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ test-min = [ test-array-api = [ { include-group = "test-min" }, { include-group = "test-jax" }, + "mlx[cpu]", ] test-jax = [ "jax>=0.6.0", diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index fa6fd299e..263c33705 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -2202,7 +2202,10 @@ def _check_2d_shape(X): Assure that X is always 2D: Unlike numpy we always deal with 2D arrays. """ - if X.dtype.names is None and len(X.shape) != 2: + if ( + isinstance(X, np.ndarray) + or (isinstance(X, np.ndarray) and X.dtype.names is None) + ) and len(X.shape) != 2: msg = f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional." raise ValueError(msg) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index de067bd27..cc2960ca4 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1132,7 +1132,10 @@ def concat_pairwise_mapping( elif all(isinstance(el, DaskArray) for el in els): result[k] = _dask_block_diag(els) else: - result[k] = sparse.block_diag(els, format="csr") + result[k] = sparse.block_diag( + [np.array(e) if not sparse.issparse(e) else e for e in els], + format="csr", + ) return result diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index f237fa5e9..6681e86ab 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -86,6 +86,21 @@ def jnp_array_or_idempotent(x): return jnp.array(x) +if TYPE_CHECKING or find_spec("mlx"): + import mlx.core as mx + + mx = _or_none(mx) +else: + mx = None + + +def mlx_array_or_idempotent(x): + if mx is None: + # In this case, the test should be marked by `array_api` to be conditionally skipped + return x + return mx.array(x) + + try: import fast_array_utils as _ except ImportError: diff --git a/src/anndata/types.py b/src/anndata/types.py index 9842ef834..c49392b9f 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -28,7 +28,6 @@ def __init__(self, adata: AnnData) -> None: @runtime_checkable class SupportsArrayApi(Protocol): - device: str shape: tuple[int, ...] def __array_namespace__( @@ -36,7 +35,6 @@ def __array_namespace__( *, api_version: Literal["2021.12", "2022.12", "2023.12", "2024.12"] | None = None, ) -> ModuleType: ... - def to_device(self, device: str, /, *, stream: int | Any | None = ...) -> Any: ... def __dlpack__( self, *, diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 419d28003..595cb173d 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -36,6 +36,7 @@ gen_vstr_recarray, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) from anndata.utils import asarray @@ -1526,6 +1527,7 @@ def expected_shape( [ pytest.param(np.array, id="np"), pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ], ) def test_concat_size_0_axis( @@ -1577,9 +1579,8 @@ def test_concat_size_0_axis( altaxis_new_inds = ~axis_labels(result, alt_axis).isin(axis_labels(a, alt_axis)) axis_idx = make_idx_tuple(axis_new_inds, axis) altaxis_idx = make_idx_tuple(altaxis_new_inds, 1 - axis) - - check_filled_like(result.X[axis_idx], elem_name="X") - check_filled_like(result.X[altaxis_idx], elem_name="X") + check_filled_like(result[axis_idx].X, elem_name="X") + check_filled_like(result[altaxis_idx].X, elem_name="X") for k, elem in result.layers.items(): check_filled_like(elem[axis_idx], elem_name=f"layers/{k}") check_filled_like(elem[altaxis_idx], elem_name=f"layers/{k}") diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 3359b2ff8..6741b238d 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -35,6 +35,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) if TYPE_CHECKING: @@ -48,6 +49,7 @@ pytest.param(csr_matrix, id="csr_matrix"), pytest.param(csr_array, id="csr_array"), pytest.param(jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ] # ------------------------------------------------------------------------------ @@ -292,6 +294,7 @@ def test_readwrite_backed(typ, backing_h5ad: Path) -> None: pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ], ) def test_readwrite_equivalent_h5ad_zarr(tmp_path: Path, typ) -> None: @@ -502,6 +505,7 @@ def test_read_tsv_iter(): pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), np.array, csr_matrix, ], @@ -518,6 +522,7 @@ def test_write_csv(typ, tmp_path): pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), np.array, csr_matrix, ], @@ -562,7 +567,11 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]: ) @pytest.mark.parametrize( "xp_array", - [np.array, pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api)], + [ + np.array, + pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), + ], ) def test_readwrite_empty(read, write, name: str, tmp_path: Path, xp_array) -> None: adata = ad.AnnData(uns=dict(empty=xp_array([]).astype(float))) diff --git a/tests/test_views.py b/tests/test_views.py index db3e44592..74a1ca461 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -45,6 +45,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, single_int_subset, single_subset, slice_int_subset, @@ -1034,6 +1035,11 @@ def test_normalize_index_jax_boolean() -> None: id="jax", marks=pytest.mark.array_api, ), + pytest.param( + type(mlx_array_or_idempotent(np.array([1]))), + id="mlx", + marks=pytest.mark.array_api, + ), *( [pytest.param(XDataArray, _from_xarray, id="xarray")] if find_spec("xarray") is not None diff --git a/tests/test_x.py b/tests/test_x.py index 60e93ca61..142048271 100644 --- a/tests/test_x.py +++ b/tests/test_x.py @@ -18,6 +18,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) from anndata.utils import asarray @@ -28,6 +29,7 @@ pytest.param(sparse.csc_array, id="csc_array"), pytest.param(asarray, id="ndarray"), pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ] SINGULAR_SHAPES = [ pytest.param(shape, id=str(shape)) for shape in [(1, 10), (10, 1), (1, 1)] From fa255372d6ffdbef99fc4f0764a7fc31e254591d Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 6 Feb 2026 19:48:46 +0100 Subject: [PATCH 2/4] fix: params --- tests/test_views.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_views.py b/tests/test_views.py index 74a1ca461..e775f0c4b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -1037,6 +1037,7 @@ def test_normalize_index_jax_boolean() -> None: ), pytest.param( type(mlx_array_or_idempotent(np.array([1]))), + _from_array, id="mlx", marks=pytest.mark.array_api, ), From 60c418b337a30155782f4a863a50c2e4233dd39e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 6 Feb 2026 19:52:23 +0100 Subject: [PATCH 3/4] fix: no need for float with empty --- tests/test_readwrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 6741b238d..0f577010b 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -574,7 +574,7 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]: ], ) def test_readwrite_empty(read, write, name: str, tmp_path: Path, xp_array) -> None: - adata = ad.AnnData(uns=dict(empty=xp_array([]).astype(float))) + adata = ad.AnnData(uns=dict(empty=xp_array([]))) write(tmp_path / name, adata) ad_read = read(tmp_path / name) assert ad_read.uns["empty"] is not None From 1aa66d78632ca20c7b4095be2b28cba7fc55e2f9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 6 Feb 2026 20:02:26 +0100 Subject: [PATCH 4/4] fix: dlpack compat --- src/anndata/compat/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 40643db03..ef8eeeccb 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -89,8 +89,9 @@ def get_for_array(self, arr: SupportsArrayApi) -> SupportsArrayApi: existing_xp = existing.__array_namespace__() if existing_xp is xp: return existing - return xp.from_dlpack(existing) - self.add_array(xp.from_dlpack(src_arr, copy=True)) + # https://github.com/ml-explore/mlx/issues/48#issuecomment-3862013095 for asarray/dlpack + return getattr(xp, "from_dlpack", xp.asarray)(existing) + self.add_array(getattr(xp, "from_dlpack", xp.asarray)(src_arr, copy=True)) return self._manager[device]