diff --git a/pyproject.toml b/pyproject.toml index 666cba289..11504a041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ test-min = [ test-array-api = [ { include-group = "test-min" }, { include-group = "test-jax" }, + "mlx[cpu]", ] test-jax = [ "jax>=0.6.0", diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index fa6fd299e..263c33705 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -2202,7 +2202,10 @@ def _check_2d_shape(X): Assure that X is always 2D: Unlike numpy we always deal with 2D arrays. """ - if X.dtype.names is None and len(X.shape) != 2: + if ( + isinstance(X, np.ndarray) + or (isinstance(X, np.ndarray) and X.dtype.names is None) + ) and len(X.shape) != 2: msg = f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional." raise ValueError(msg) diff --git a/src/anndata/_core/merge.py b/src/anndata/_core/merge.py index de067bd27..cc2960ca4 100644 --- a/src/anndata/_core/merge.py +++ b/src/anndata/_core/merge.py @@ -1132,7 +1132,10 @@ def concat_pairwise_mapping( elif all(isinstance(el, DaskArray) for el in els): result[k] = _dask_block_diag(els) else: - result[k] = sparse.block_diag(els, format="csr") + result[k] = sparse.block_diag( + [np.array(e) if not sparse.issparse(e) else e for e in els], + format="csr", + ) return result diff --git a/src/anndata/compat/__init__.py b/src/anndata/compat/__init__.py index 40643db03..ef8eeeccb 100644 --- a/src/anndata/compat/__init__.py +++ b/src/anndata/compat/__init__.py @@ -89,8 +89,9 @@ def get_for_array(self, arr: SupportsArrayApi) -> SupportsArrayApi: existing_xp = existing.__array_namespace__() if existing_xp is xp: return existing - return xp.from_dlpack(existing) - self.add_array(xp.from_dlpack(src_arr, copy=True)) + # https://github.com/ml-explore/mlx/issues/48#issuecomment-3862013095 for asarray/dlpack + return getattr(xp, "from_dlpack", xp.asarray)(existing) + self.add_array(getattr(xp, "from_dlpack", xp.asarray)(src_arr, copy=True)) return self._manager[device] diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index f237fa5e9..6681e86ab 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -86,6 +86,21 @@ def jnp_array_or_idempotent(x): return jnp.array(x) +if TYPE_CHECKING or find_spec("mlx"): + import mlx.core as mx + + mx = _or_none(mx) +else: + mx = None + + +def mlx_array_or_idempotent(x): + if mx is None: + # In this case, the test should be marked by `array_api` to be conditionally skipped + return x + return mx.array(x) + + try: import fast_array_utils as _ except ImportError: diff --git a/src/anndata/types.py b/src/anndata/types.py index 9842ef834..c49392b9f 100644 --- a/src/anndata/types.py +++ b/src/anndata/types.py @@ -28,7 +28,6 @@ def __init__(self, adata: AnnData) -> None: @runtime_checkable class SupportsArrayApi(Protocol): - device: str shape: tuple[int, ...] def __array_namespace__( @@ -36,7 +35,6 @@ def __array_namespace__( *, api_version: Literal["2021.12", "2022.12", "2023.12", "2024.12"] | None = None, ) -> ModuleType: ... - def to_device(self, device: str, /, *, stream: int | Any | None = ...) -> Any: ... def __dlpack__( self, *, diff --git a/tests/test_concatenate.py b/tests/test_concatenate.py index 419d28003..595cb173d 100644 --- a/tests/test_concatenate.py +++ b/tests/test_concatenate.py @@ -36,6 +36,7 @@ gen_vstr_recarray, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) from anndata.utils import asarray @@ -1526,6 +1527,7 @@ def expected_shape( [ pytest.param(np.array, id="np"), pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ], ) def test_concat_size_0_axis( @@ -1577,9 +1579,8 @@ def test_concat_size_0_axis( altaxis_new_inds = ~axis_labels(result, alt_axis).isin(axis_labels(a, alt_axis)) axis_idx = make_idx_tuple(axis_new_inds, axis) altaxis_idx = make_idx_tuple(altaxis_new_inds, 1 - axis) - - check_filled_like(result.X[axis_idx], elem_name="X") - check_filled_like(result.X[altaxis_idx], elem_name="X") + check_filled_like(result[axis_idx].X, elem_name="X") + check_filled_like(result[altaxis_idx].X, elem_name="X") for k, elem in result.layers.items(): check_filled_like(elem[axis_idx], elem_name=f"layers/{k}") check_filled_like(elem[altaxis_idx], elem_name=f"layers/{k}") diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 3359b2ff8..0f577010b 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -35,6 +35,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) if TYPE_CHECKING: @@ -48,6 +49,7 @@ pytest.param(csr_matrix, id="csr_matrix"), pytest.param(csr_array, id="csr_array"), pytest.param(jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ] # ------------------------------------------------------------------------------ @@ -292,6 +294,7 @@ def test_readwrite_backed(typ, backing_h5ad: Path) -> None: pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ], ) def test_readwrite_equivalent_h5ad_zarr(tmp_path: Path, typ) -> None: @@ -502,6 +505,7 @@ def test_read_tsv_iter(): pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), np.array, csr_matrix, ], @@ -518,6 +522,7 @@ def test_write_csv(typ, tmp_path): pytest.param( jnp_array_or_idempotent, id="jax.array", marks=pytest.mark.array_api ), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), np.array, csr_matrix, ], @@ -562,10 +567,14 @@ def hash_dir_contents(dir: Path) -> dict[str, bytes]: ) @pytest.mark.parametrize( "xp_array", - [np.array, pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api)], + [ + np.array, + pytest.param(jnp_array_or_idempotent, marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), + ], ) def test_readwrite_empty(read, write, name: str, tmp_path: Path, xp_array) -> None: - adata = ad.AnnData(uns=dict(empty=xp_array([]).astype(float))) + adata = ad.AnnData(uns=dict(empty=xp_array([]))) write(tmp_path / name, adata) ad_read = read(tmp_path / name) assert ad_read.uns["empty"] is not None diff --git a/tests/test_views.py b/tests/test_views.py index db3e44592..e775f0c4b 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -45,6 +45,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, single_int_subset, single_subset, slice_int_subset, @@ -1034,6 +1035,12 @@ def test_normalize_index_jax_boolean() -> None: id="jax", marks=pytest.mark.array_api, ), + pytest.param( + type(mlx_array_or_idempotent(np.array([1]))), + _from_array, + id="mlx", + marks=pytest.mark.array_api, + ), *( [pytest.param(XDataArray, _from_xarray, id="xarray")] if find_spec("xarray") is not None diff --git a/tests/test_x.py b/tests/test_x.py index 60e93ca61..142048271 100644 --- a/tests/test_x.py +++ b/tests/test_x.py @@ -18,6 +18,7 @@ gen_adata, jnp, jnp_array_or_idempotent, + mlx_array_or_idempotent, ) from anndata.utils import asarray @@ -28,6 +29,7 @@ pytest.param(sparse.csc_array, id="csc_array"), pytest.param(asarray, id="ndarray"), pytest.param(jnp_array_or_idempotent, id="jax", marks=pytest.mark.array_api), + pytest.param(mlx_array_or_idempotent, id="mlx", marks=pytest.mark.array_api), ] SINGULAR_SHAPES = [ pytest.param(shape, id=str(shape)) for shape in [(1, 10), (10, 1), (1, 1)]