From 111d11891c233699dc9ac33bee1522823c5ab857 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 13:17:54 +0200 Subject: [PATCH 01/19] add MuData accessors --- src/mudata/__init__.py | 5 ++ src/mudata/_core/mudata.py | 12 ++- src/mudata/acc/__init__.py | 172 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 6 +- tests/test_accessors.py | 84 ++++++++++++++++++ tests/test_obs_var.py | 14 --- 6 files changed, 275 insertions(+), 18 deletions(-) create mode 100644 src/mudata/acc/__init__.py create mode 100644 tests/test_accessors.py diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index a526a9b..bf7695d 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -1,5 +1,7 @@ """Multimodal datasets""" +from contextlib import suppress + from anndata import AnnData from scverse_misc import ExtensionNamespace from scverse_misc import make_register_namespace_decorator as _make_register_namespace_decorator @@ -23,6 +25,9 @@ from ._core.to_ import to_anndata, to_mudata from ._version import __version__, __version_tuple__ +with suppress(ImportError): + from . import acc + # file format versions __anndataversion__ = "0.1.0" __mudataversion__ = "0.1.0" diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index da0b38f..c42e0a3 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,9 +555,17 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc - - if isinstance(key, AdRef | RefAcc | MapAcc): + from ..acc import ModAcc, ModMapAcc, _ModalityMapAcc, _ModalityMixin + + if isinstance(key, ModAcc | _ModalityMapAcc): + return key.isin(self) + elif isinstance(key, _ModalityMixin): + return AnnData.__contains__(self.mod[key.mod], key) + elif isinstance(key, ModMapAcc): + return bool(self.mod) + elif isinstance(key, AdRef | RefAcc | MapAcc): return AnnData.__contains__(self, key) + raise TypeError(f"Unexpected key {key!r}.") @property diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py new file mode 100644 index 0000000..527bdcf --- /dev/null +++ b/src/mudata/acc/__init__.py @@ -0,0 +1,172 @@ +from __future__ import annotations + +from dataclasses import KW_ONLY, dataclass, field +from typing import TYPE_CHECKING + +import pandas as pd +from anndata.acc import ( + AdAcc, + AdRef, + GraphAcc, + GraphMapAcc, + Idx2D, + LayerAcc, + LayerMapAcc, + MapAcc, + MetaAcc, + MultiAcc, + MultiMapAcc, +) +from anndata.compat import XVariable +from anndata.typing import InMemoryArray + +if TYPE_CHECKING: + from anndata import AnnData + + from .. import MuData + + +@dataclass(frozen=True, kw_only=True) +class _ModalityMixin: + mod: str + + +@dataclass(frozen=True) +class _ModalityMapAcc[I, R](_ModalityMixin): + def isin(self, mdata: MuData, idx: I | None = None) -> bool: + if self.mod not in mdata.mod: + return False + else: + return super().isin(mdata.mod[self.mod], idx) + + def get(self, mdata: MuData, idx: I, /) -> R: + return super().get(mdata.mod[self.mod], idx) + + +@dataclass(frozen=True) +class ModLayerAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], LayerAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" + + +@dataclass(frozen=True, kw_only=True) +class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): + ref_acc_cls: type[ModLayerAcc] = ModLayerAcc + + def __getitem__(self, k: str | None, /) -> ModLayerAcc[R]: + if not isinstance(k, str | None): + raise TypeError(f"Unsupported layer {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].layers" + + +@dataclass(frozen=True) +class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}" + + +@dataclass(frozen=True) +class ModMultiAcc[R: AdRef[int]](_ModalityMapAcc[int, InMemoryArray], MultiAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" + + +@dataclass(frozen=True, kw_only=True) +class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): + ref_acc_cls: type[ModMultiAcc] = ModMultiAcc + + def __getitem__(self, k: str, /) -> ModMultiAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}m key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}m" + + +@dataclass(frozen=True) +class ModGraphAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], GraphAcc[R]): + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" + + +@dataclass(frozen=True, kw_only=True) +class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): + ref_acc_cls: type[ModGraphAcc] = ModGraphAcc + + def __getitem__(self, k: str, /) -> ModGraphAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported {self.dim}p key {k!r}") + return self.ref_acc_cls(mod=self.mod, k=k, dim=self.dim, ref_class=self.ref_class) + + def __repr__(self) -> str: + return f"A.mod[{self.mod!r}].{self.dim}p" + + +@dataclass(frozen=True, kw_only=True) +class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): + layer_cls: type[ModLayerAcc] = ModLayerAcc + meta_cls: type[ModMetaAcc] = ModMetaAcc + multi_cls: type[ModMultiAcc] = ModMultiAcc + graph_cls: type[ModGraphAcc] = ModGraphAcc + + def isin(self, mdata: MuData) -> bool: + return self.mod in mdata.mod + + def get(self, mdata: MuData) -> ad.AnnData: + return mdata.mod[self.mod] + + def __post_init__(self) -> None: + x = self.layer_cls(mod=self.mod, k=None, ref_class=self.ref_class) + layers = ModLayerMapAcc(mod=self.mod, ref_class=self.ref_class, ref_acc_cls=self.layer_cls) + object.__setattr__(self, "X", x) + object.__setattr__(self, "layers", layers) + for dim in ("obs", "var"): + meta = self.meta_cls(mod=self.mod, dim=dim, ref_class=self.ref_class) + multi = ModMultiMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.multi_cls) + graphs = ModGraphMapAcc(mod=self.mod, dim=dim, ref_class=self.ref_class, ref_acc_cls=self.graph_cls) + object.__setattr__(self, dim, meta) + object.__setattr__(self, f"{dim}m", multi) + object.__setattr__(self, f"{dim}p", graphs) + + def __repr__(self) -> str: + return f"A.mod[{self.mod}]" + + +@dataclass(frozen=True) +class ModMapAcc[R: AdRef](MapAcc[ModAcc[R]]): + ref_class: type[R] + ref_acc_cls: type[ModAcc] = ModAcc + + def __getitem__(self, k: str, /) -> ModAcc[R]: + if not isinstance(k, str): + raise TypeError(f"Unsupported mod key {k!r}") + return self.ref_acc_cls(mod=k, ref_class=self.ref_class) + + def __repr__(self) -> str: + return "A.mod" + + +@dataclass(frozen=True) +class MuAcc[R: AdRef](AdAcc[R]): + mod_cls: type[ModAcc] = ModAcc + """Class to use for `mod` accessors.""" + + mod: ModMapAcc[R] = field(init=False) + + def __post_init__(self) -> None: + super().__post_init__() + mod = ModMapAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls) + object.__setattr__(self, "mod", mod) + + def __getitem__(self, k: str, /) -> ModAcc[R]: + return self.mod[k] + + def __repr__(self) -> str: + return "A" + + +A = MuAcc() diff --git a/tests/conftest.py b/tests/conftest.py index d637007..9a4bc66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,10 +37,12 @@ def rng() -> np.random.Generator: def mdata(rng: np.random.Generator, request: pytest.FixtureRequest) -> MuData: axis = getattr(request, "param", 0) mod1 = AnnData( - np.arange(0, 200, 0.1).reshape(-1, 20), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(0, 200, 0.1).reshape(-1, 20), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod2 = AnnData( - np.arange(101, 3101, 1).reshape(-1, 30), obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False)) + np.arange(101, 3101, 1).reshape(-1, 30), + obs=pd.DataFrame(index=rng.choice(150, size=100, replace=False).astype(str)), ) mod1.var["assert-bool"] = True mod2.var["assert-bool"] = False diff --git a/tests/test_accessors.py b/tests/test_accessors.py new file mode 100644 index 0000000..cf5a4b3 --- /dev/null +++ b/tests/test_accessors.py @@ -0,0 +1,84 @@ +import anndata as ad +import numpy as np +import pandas as pd +import pytest +from packaging.version import Version + +import mudata as md +from mudata.acc import A + +if Version(ad.__version__) < Version("0.13dev0"): + pytest.skip("anndata version too old, no accessor support", allow_module_level=True) + + +@pytest.fixture +def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): + mdata["mod1"].layers["counts"] = rng.poisson(1, size=mdata["mod1"].shape) + mdata["mod2"].obsp["test"] = rng.normal(size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)) + + return mdata + + +def test_anndata_accessors(mdata: md.MuData): + assert ad.acc.A.obs["arange"] in mdata + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + with pytest.raises(KeyError, match="test"): + mdata[ad.acc.A.var["test"]] + with pytest.raises(KeyError, match="there is one in"): + mdata[ad.acc.A.var["assert-bool"]] + + +PATHS = [ + (A.mod["mod1"], lambda md: md.mod["mod1"]), + (A["mod1"], lambda md: md["mod1"]), + (A.mod["mod1"].var, lambda md: md.mod["mod1"].var), + (A.mod["mod1"].var["assert-bool"], lambda md: md.mod["mod1"].var["assert-bool"]), + (A.mod["mod1"].X, lambda md: md.mod["mod1"].X), + (A.mod["mod1"].X["obs_2", :], lambda md: md.mod["mod1"]["obs_2", :].X.squeeze()), + (A.mod["mod1"].X[:, "mod1_var_1"], lambda md: md.mod["mod1"][:, "mod1_var_1"].X.squeeze()), + (A["mod1"].layers, lambda md: md["mod1"].layers), + (A["mod1"].layers["counts"], lambda md: md["mod1"].layers["counts"]), + (A["mod1"].layers["counts"]["obs_2", :], lambda md: md["mod1"]["obs_2", :].layers["counts"].squeeze()), + (A["mod2"].obsp, lambda md: md["mod2"].obsp), + (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), + (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), +] + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS]) +def test_in(mdata_augmented: md.MuData, acc): + assert acc in mdata_augmented + + +@pytest.mark.parametrize( + "acc", + [ + A.mod["mod3"], + A["mod3"], + A.mod["mod3"].var, + A.mod["mod1"].var["does_not_exist"], + A.mod["mod3"].X, + A.mod["mod3"].X["obs_2", :], + A.mod["mod3"].X[:, "mod1_var_1"], + A["mod2"].layers, + A["mod1"].layers["does_not_exist"], + A["mod1"].layers["does_not_exist"]["obs_2", :], + A["mod1"].obsp, + A["mod2"].obsp["does_not_exist"], + A["mod2"].obsp["does_not_exist"][:, "obs_3"], + ], +) +def test_not_in(mdata: md.MuData, acc): + assert acc not in mdata + + +@pytest.mark.parametrize("acc_expected", [path for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) +def test_get(mdata_augmented: md.MuData, acc_expected): + acc, expected = acc_expected + + val = mdata_augmented[acc] + expected = expected(mdata_augmented) + if isinstance(expected, pd.DataFrame | pd.Series | np.ndarray): + assert np.all(val == expected) + else: + assert val == expected diff --git a/tests/test_obs_var.py b/tests/test_obs_var.py index 4636b4f..3057c50 100644 --- a/tests/test_obs_var.py +++ b/tests/test_obs_var.py @@ -1,10 +1,8 @@ from pathlib import Path -import anndata as ad import numpy as np import pandas as pd import pytest -from packaging.version import Version import mudata as md @@ -147,15 +145,3 @@ def test_names_make_unique(mdata: md.MuData): with pytest.raises(TypeError, match="axis="): getattr(mdata, f"{attr}_names_make_unique")() - - -@pytest.mark.skipif( - Version(ad.__version__) < Version("0.13dev0"), reason="anndata version too old, no accessor support" -) -def test_accessors(mdata: md.MuData): - assert ad.acc.A.obs["arange"] in mdata - assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() - with pytest.raises(KeyError, match="test"): - mdata[ad.acc.A.var["test"]] - with pytest.raises(KeyError, match="there is one in"): - mdata[ad.acc.A.var["assert-bool"]] From 7580584bace6d371bba87f8ded8fb8a7160850da Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 14:22:58 +0200 Subject: [PATCH 02/19] prevent X and layers accessors for MuData objects --- src/mudata/acc/__init__.py | 5 +++++ tests/test_accessors.py | 7 +++++++ 2 files changed, 12 insertions(+) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 527bdcf..b10573d 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -168,5 +168,10 @@ def __getitem__(self, k: str, /) -> ModAcc[R]: def __repr__(self) -> str: return "A" + def __getattribute__(self, name: str): + if name in ("X", "layers"): + raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {name!r}") + return super().__getattribute__(name) + A = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index cf5a4b3..71600b0 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -82,3 +82,10 @@ def test_get(mdata_augmented: md.MuData, acc_expected): assert np.all(val == expected) else: assert val == expected + + +def test_no_data(): + with pytest.raises(AttributeError): + A.X # noqa: B018 + with pytest.raises(AttributeError): + A.layers # noqa: B018 From d34fd49c204206b806f77b329c9bf4ec07307cea Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 14:57:33 +0200 Subject: [PATCH 03/19] add obsmap/varmap accessors --- src/mudata/_core/mudata.py | 6 +++--- src/mudata/acc/__init__.py | 43 +++++++++++++++++++++++++++++++------- tests/test_accessors.py | 6 ++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index c42e0a3..e661653 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,13 +555,13 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc - from ..acc import ModAcc, ModMapAcc, _ModalityMapAcc, _ModalityMixin + from ..acc import ModAcc, MultiModAcc, _ModalityMapAcc, _ModalityMixin if isinstance(key, ModAcc | _ModalityMapAcc): return key.isin(self) elif isinstance(key, _ModalityMixin): - return AnnData.__contains__(self.mod[key.mod], key) - elif isinstance(key, ModMapAcc): + return key in self.mod[key.mod] + elif isinstance(key, MultiModAcc): return bool(self.mod) elif isinstance(key, AdRef | RefAcc | MapAcc): return AnnData.__contains__(self, key) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index b10573d..64123e2 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,12 +1,13 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pandas as pd from anndata.acc import ( AdAcc, AdRef, + Axes, GraphAcc, GraphMapAcc, Idx2D, @@ -16,6 +17,7 @@ MetaAcc, MultiAcc, MultiMapAcc, + RefAcc, ) from anndata.compat import XVariable from anndata.typing import InMemoryArray @@ -49,7 +51,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): ref_acc_cls: type[ModLayerAcc] = ModLayerAcc @@ -74,7 +76,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): ref_acc_cls: type[ModMultiAcc] = ModMultiAcc @@ -93,7 +95,7 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" -@dataclass(frozen=True, kw_only=True) +@dataclass(frozen=True) class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): ref_acc_cls: type[ModGraphAcc] = ModGraphAcc @@ -106,6 +108,28 @@ def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p" +@dataclass(frozen=True) +class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): + dim: Literal[obs, var] + + def dims(self, idx: Any, /) -> Axes: + return (self.dim,) + + def __repr__(self) -> str: + return f"A.{self.dim}map" + + def idx_repr(self, idx: str, /) -> str: + return f"[{idx}]" + + def isin(self, mdata: MuData, idx: str | None = None) -> bool: + m = getattr(mdata, f"{self.dim}map") + return idx is None or idx in m + + def get(self, mdata: MuData, idx: str, /) -> InMemoryArray: + m = getattr(mdata, f"{self.dim}map") + return m[idx] + + @dataclass(frozen=True, kw_only=True) class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): layer_cls: type[ModLayerAcc] = ModLayerAcc @@ -137,7 +161,7 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class ModMapAcc[R: AdRef](MapAcc[ModAcc[R]]): +class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): ref_class: type[R] ref_acc_cls: type[ModAcc] = ModAcc @@ -155,12 +179,15 @@ class MuAcc[R: AdRef](AdAcc[R]): mod_cls: type[ModAcc] = ModAcc """Class to use for `mod` accessors.""" - mod: ModMapAcc[R] = field(init=False) + mod: MultiModAcc[R] = field(init=False) + obsmap: ModMapAcc[R] = field(init=False) + varmap: ModMapAcc[R] = field(init=False) def __post_init__(self) -> None: super().__post_init__() - mod = ModMapAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls) - object.__setattr__(self, "mod", mod) + object.__setattr__(self, "mod", MultiModAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls)) + object.__setattr__(self, "obsmap", ModMapAcc("obs", ref_class=self.ref_class)) + object.__setattr__(self, "varmap", ModMapAcc("var", ref_class=self.ref_class)) def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 71600b0..a5ce402 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -42,6 +42,10 @@ def test_anndata_accessors(mdata: md.MuData): (A["mod2"].obsp, lambda md: md["mod2"].obsp), (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), + (A.obsmap, lambda md: md.obsmap), + (A.varmap, lambda md: md.varmap), + (A.obsmap["mod1"], lambda md: md.obsmap["mod1"]), + (A.varmap["mod2"], lambda md: md.varmap["mod2"]), ] @@ -66,6 +70,8 @@ def test_in(mdata_augmented: md.MuData, acc): A["mod1"].obsp, A["mod2"].obsp["does_not_exist"], A["mod2"].obsp["does_not_exist"][:, "obs_3"], + A.obsmap["mod3"], + A.varmap["mod3"], ], ) def test_not_in(mdata: md.MuData, acc): From 8e303f7a3729c1e23bc4826b9ffea5d4d15e9f8a Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 15:09:01 +0200 Subject: [PATCH 04/19] fixup! prevent X and layers accessors for MuData objects --- src/mudata/acc/__init__.py | 10 +++++----- tests/test_accessors.py | 5 +++++ 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 64123e2..0eb683b 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -189,16 +189,16 @@ def __post_init__(self) -> None: object.__setattr__(self, "obsmap", ModMapAcc("obs", ref_class=self.ref_class)) object.__setattr__(self, "varmap", ModMapAcc("var", ref_class=self.ref_class)) + del self.__dict__["X"] + del self.__dict__["layers"] + del self.__dataclass_fields__["X"] + del self.__dataclass_fields__["layers"] + def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] def __repr__(self) -> str: return "A" - def __getattribute__(self, name: str): - if name in ("X", "layers"): - raise AttributeError(f"{self.__class__.__name__!r} object has no attribute {name!r}") - return super().__getattribute__(name) - A = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index a5ce402..745b469 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -1,3 +1,5 @@ +from dataclasses import fields + import anndata as ad import numpy as np import pandas as pd @@ -95,3 +97,6 @@ def test_no_data(): A.X # noqa: B018 with pytest.raises(AttributeError): A.layers # noqa: B018 + + for field in fields(A): + assert field.name not in ("X", "layers") From 8756654f0b941d1f551d3eba0b74f2c1b96f4349 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 16:31:31 +0200 Subject: [PATCH 05/19] add docs --- docs/accessors.md | 34 ++++++++++++++++++++++++++ docs/api.md | 8 ++++++ docs/conf.py | 2 +- src/mudata/acc/__init__.py | 50 +++++++++++++++++++++++++++++++++++--- 4 files changed, 89 insertions(+), 5 deletions(-) create mode 100644 docs/accessors.md diff --git a/docs/accessors.md b/docs/accessors.md new file mode 100644 index 0000000..d711789 --- /dev/null +++ b/docs/accessors.md @@ -0,0 +1,34 @@ +# Accessors and paths + +```{eval-rst} +.. module:: mudata.acc +``` + +[](#mudata.acc) provides [accessors](inv:anndata:*:term#accessor) that create [references](inv:anndata:*:term#reference) to axis-aligned 1D and 2D arrays in [MuData](#mudata.MuData) objects. +See the corresponding [AnnData documentation](inv:anndata:*:doc#accessors). + +:::{important} +This functionality requires AnnData 0.13 or later. +::: + +The central [accessor](inv:anndata:*:term#accessor) is [](#A). +```{eval-rst} +.. autodata:: A +``` +See [](#MuAcc) and [AdAcc](#anndata.acc.AdAcc) for examples of how to use it to create [references](inv:anndata:*:term#reference) (i.e. [AdRefs](#anndata.acc.AdRef)). + +```{eval-rst} +.. autosummary:: + :toctree: generated + + MuAcc + MultiModAcc + ModAcc + ModMapAcc + ModMetaAcc + ModLayerAcc + ModGraphAcc + ModMultiAcc + ModMultiMapAcc + ModGraphMapAcc +``` diff --git a/docs/api.md b/docs/api.md index 9b2291a..00cae91 100644 --- a/docs/api.md +++ b/docs/api.md @@ -35,6 +35,14 @@ write_zarr ``` +## Accessors +```{eval-rst} +.. toctree:: + :hidden: + + mudata.acc +``` + ## Extensions ```{eval-rst} .. autosummary:: diff --git a/docs/conf.py b/docs/conf.py index 97fc30c..8ab633a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -97,7 +97,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 0eb683b..1d91d3c 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import pandas as pd from anndata.acc import ( @@ -31,6 +31,7 @@ @dataclass(frozen=True, kw_only=True) class _ModalityMixin: mod: str + """Modality this accessor refers to.""" @dataclass(frozen=True) @@ -47,12 +48,16 @@ def get(self, mdata: MuData, idx: I, /) -> R: @dataclass(frozen=True) class ModLayerAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], LayerAcc[R]): + """Reference accessor for arrays in :attr:`~anndata.acc.AdAcc.layers`.""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].X" if self.k is None else f"A.mod[{self.mod}].layers[{self.k!r}]" @dataclass(frozen=True) class ModLayerMapAcc[R: AdRef](_ModalityMixin, LayerMapAcc[R]): + """Accessor for arrays in :attr:~anndata.acc.AdAcc.layers`.""" + ref_acc_cls: type[ModLayerAcc] = ModLayerAcc def __getitem__(self, k: str | None, /) -> ModLayerAcc[R]: @@ -66,18 +71,24 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): + """Reference accessor for arrays from metadata containers (:attr:`~anndata.acc.AdAcc.obs` / :attr:`~anndata.acc.AdAcc.var`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}" @dataclass(frozen=True) class ModMultiAcc[R: AdRef[int]](_ModalityMapAcc[int, InMemoryArray], MultiAcc[R]): + """Reference accessor for arrays from multi-dimensional containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}m[self.k!r]" @dataclass(frozen=True) class ModMultiMapAcc[R: AdRef](_ModalityMixin, MultiMapAcc[R]): + """Accessor for multi-dimensional array containers (:attr:`~anndata.acc.AdAcc.obsm` / :attr:`~anndata.acc.AdAcc.varm`).""" + ref_acc_cls: type[ModMultiAcc] = ModMultiAcc def __getitem__(self, k: str, /) -> ModMultiAcc[R]: @@ -91,12 +102,16 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModGraphAcc[R: AdRef[Idx2D]](_ModalityMapAcc[Idx2D, InMemoryArray], GraphAcc[R]): + """Reference accessor for arrays from graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`).""" + def __repr__(self) -> str: return f"A.mod[{self.mod!r}].{self.dim}p[{self.k!r}]" @dataclass(frozen=True) class ModGraphMapAcc[R: AdRef](_ModalityMixin, GraphMapAcc[R]): + """Accessor for graph containers (:attr:`~anndata.acc.AdAcc.obsp` / :attr:`~anndata.acc.AdAcc.varp`)""" + ref_acc_cls: type[ModGraphAcc] = ModGraphAcc def __getitem__(self, k: str, /) -> ModGraphAcc[R]: @@ -110,37 +125,55 @@ def __repr__(self) -> str: @dataclass(frozen=True) class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): - dim: Literal[obs, var] + """Reference accessor for modality maps (:attr:`~MuAcc.obsmap` / :attr:`~MuAcc.varmap`).""" + + dim: Literal["obs", "var"] + """Axis this accessor refers to, e.g. `A.obsmap[k].dim == "var"`.""" def dims(self, idx: Any, /) -> Axes: + """Get which dimension this array refers to.""" return (self.dim,) def __repr__(self) -> str: return f"A.{self.dim}map" def idx_repr(self, idx: str, /) -> str: + """Get a string representation of the index.""" return f"[{idx}]" def isin(self, mdata: MuData, idx: str | None = None) -> bool: + """Check if the referenced array is in the :class:`~mudata.MuData` object.""" m = getattr(mdata, f"{self.dim}map") return idx is None or idx in m def get(self, mdata: MuData, idx: str, /) -> InMemoryArray: + """Get the referenced array from the :class:`~mudata.MuData` object.""" m = getattr(mdata, f"{self.dim}map") return m[idx] @dataclass(frozen=True, kw_only=True) class ModAcc[R: AdRef](_ModalityMixin, AdAcc[R]): + """Accessor to create :class:`AdRefs ` (:data:`A`) for modalities (:attr:`~MuAcc.mod`).""" + layer_cls: type[ModLayerAcc] = ModLayerAcc + """Class to use for `layers` accessors.""" + meta_cls: type[ModMetaAcc] = ModMetaAcc + """Class to use for `obs`/`var` accessors.""" + multi_cls: type[ModMultiAcc] = ModMultiAcc + """Class to use for `obsm`/`varm` accessors.""" + graph_cls: type[ModGraphAcc] = ModGraphAcc + """Class to use for `obsp`/`varp` accessors.""" def isin(self, mdata: MuData) -> bool: + """Check if the referenced modality is in the :class:`~mudata.MuData` object.""" return self.mod in mdata.mod - def get(self, mdata: MuData) -> ad.AnnData: + def get(self, mdata: MuData) -> AnnData: + """Get the referenced modality from the :class:`~mudata.MuData` object.""" return mdata.mod[self.mod] def __post_init__(self) -> None: @@ -162,6 +195,8 @@ def __repr__(self) -> str: @dataclass(frozen=True) class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): + """Accessor for modalities (:attr:`~MuAcc.mod`).""" + ref_class: type[R] ref_acc_cls: type[ModAcc] = ModAcc @@ -176,12 +211,19 @@ def __repr__(self) -> str: @dataclass(frozen=True) class MuAcc[R: AdRef](AdAcc[R]): + """Accessor to create :class:`AdRefs ` (:data:`A`).""" + mod_cls: type[ModAcc] = ModAcc """Class to use for `mod` accessors.""" mod: MultiModAcc[R] = field(init=False) + """Access modalities.""" + obsmap: ModMapAcc[R] = field(init=False) + """Access mappings of observation indices in the MuData to indices in individual modalities.""" + varmap: ModMapAcc[R] = field(init=False) + """Access mappings of variable indices in the MuData to indices in individual modalities.""" def __post_init__(self) -> None: super().__post_init__() @@ -201,4 +243,4 @@ def __repr__(self) -> str: return "A" -A = MuAcc() +A: MuAcc[AdRef] = MuAcc() From c6e5c14b8b731084a97949b5dfbdb0298ca59ea8 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Fri, 5 Jun 2026 17:11:17 +0200 Subject: [PATCH 06/19] implement resolve() --- src/mudata/acc/__init__.py | 32 +++++++++++++++++++++++++++++++- tests/test_accessors.py | 14 ++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 1d91d3c..4c26293 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import KW_ONLY, dataclass, field -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, ClassVar, Literal import pandas as pd from anndata.acc import ( @@ -225,6 +225,8 @@ class MuAcc[R: AdRef](AdAcc[R]): varmap: ModMapAcc[R] = field(init=False) """Access mappings of variable indices in the MuData to indices in individual modalities.""" + ATTRS: ClassVar = frozenset(("mod", "obs", "var", "obsm", "varm", "obsp", "varp", "obsmap", "varmap")) + def __post_init__(self) -> None: super().__post_init__() object.__setattr__(self, "mod", MultiModAcc(ref_class=self.ref_class, ref_acc_cls=self.mod_cls)) @@ -242,5 +244,33 @@ def __getitem__(self, k: str, /) -> ModAcc[R]: def __repr__(self) -> str: return "A" + def resolve(self, spec: str, *, strict: bool = True) -> R | None: + """Create :class:`~anndata.acc.AdRef` from a simplified string.""" + if not strict: + try: + self.resolve(spec) + except ValueError: + return None + + firstdot = spec.find(".") + if firstdot < 0: + raise ValueError(f"Cannot parse accessor {spec!r} that is not period-separated.") + firstattr = spec[:firstdot] + match firstattr: + case "mod": + modend = spec.find(".", firstdot + 1) + mod = spec[firstdot + 1 : modend] + if not mod: + raise ValueError(f"Cannot parse accessor{spec!r} that has an empty modality.") + acc = self.mod[mod] + return super().resolve.__func__(acc, spec[modend + 1 :], strict=strict) + case "obsmap" | "varmap": + if firstdot == len(spec): + raise ValueError(f"Cannot parse accessor{spec!r} that has an empty modality.") + mod = spec[firstdot + 1 :] + return getattr(self, firstattr)[mod] + case _: + return super().resolve(spec, strict=strict) + A: MuAcc[AdRef] = MuAcc() diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 745b469..d6e2405 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -100,3 +100,17 @@ def test_no_data(): for field in fields(A): assert field.name not in ("X", "layers") + + +def test_resolve(): + assert A.resolve("mod.rna.X[:, ACT1]") == A.mod["rna"].X[:, "ACT1"] + assert A.resolve("obsmap.rna") == A.obsmap["rna"] + + with pytest.raises(ValueError, match="Unknown accessor"): + A.resolve("rna.X[:, :]") + + with pytest.raises(ValueError, match="empty modality"): + A.resolve("mod..X[:, :]") + + with pytest.raises(ValueError, match="period-separated"): + A.resolve("abcd") From 14de44a62ea323617b2ce89685a09806810d34ae Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 11:22:44 +0200 Subject: [PATCH 07/19] fix docs skip over X and layers attributes in the autosummary template, this is the only way to keep them out of MuAcc docs --- .../_templates/autosummary/class-accessor.rst | 26 +++++++++++++++++++ docs/accessors.md | 3 ++- docs/conf.py | 7 +++-- docs/extensions/skip_private_bases.py | 17 ++++++++++++ pyproject.toml | 2 +- src/mudata/acc/__init__.py | 17 ++++++++---- 6 files changed, 63 insertions(+), 9 deletions(-) create mode 100644 docs/_templates/autosummary/class-accessor.rst create mode 100644 docs/extensions/skip_private_bases.py diff --git a/docs/_templates/autosummary/class-accessor.rst b/docs/_templates/autosummary/class-accessor.rst new file mode 100644 index 0000000..cf3deb0 --- /dev/null +++ b/docs/_templates/autosummary/class-accessor.rst @@ -0,0 +1,26 @@ +{{ fullname | escape | underline}} + +{% set attributes = attributes | select("ne", "X") | select("ne", "layers") %} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :show-inheritance: + + {% block attributes %} + {%- for item in attributes %} + {%- if loop.first %} + .. rubric:: Attributes + {% endif %} + .. autoattribute:: {{ item }} + {%- endfor %} + {% endblock %} + + {% block methods %} + {%- for item in methods if item != "__init__" and item not in inherited_members %} + {%- if loop.first %} + .. rubric:: Methods + {% endif %} + .. automethod:: {{ item }} + {%- endfor %} + {% endblock %} diff --git a/docs/accessors.md b/docs/accessors.md index d711789..45ec21a 100644 --- a/docs/accessors.md +++ b/docs/accessors.md @@ -8,7 +8,7 @@ See the corresponding [AnnData documentation](inv:anndata:*:doc#accessors). :::{important} -This functionality requires AnnData 0.13 or later. +This functionality requires AnnData 0.13 or newer. ::: The central [accessor](inv:anndata:*:term#accessor) is [](#A). @@ -20,6 +20,7 @@ See [](#MuAcc) and [AdAcc](#anndata.acc.AdAcc) for examples of how to use it to ```{eval-rst} .. autosummary:: :toctree: generated + :template: class-accessor MuAcc MultiModAcc diff --git a/docs/conf.py b/docs/conf.py index 8ab633a..493c8d5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -135,8 +135,11 @@ pygments_style = "default" katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None -nitpick_ignore = [ +nitpick_ignore = ( # If building the documentation fails because of a missing link that is outside your control, # you can add an exception to this list. # ("py:class", "igraph.Graph"), -] + ("py:obj", "typing.R"), + ("py:obj", "anndata.acc.Axes"), + ("py:class", "AdRef"), +) diff --git a/docs/extensions/skip_private_bases.py b/docs/extensions/skip_private_bases.py new file mode 100644 index 0000000..eb4a07c --- /dev/null +++ b/docs/extensions/skip_private_bases.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Generic, get_origin + +from sphinx.util.typing import ExtensionMetadata + +if TYPE_CHECKING: + from sphinx.application import Sphinx + + +def skip_private_bases(app: Sphinx, name: str, obj: type, _unused, bases: list[type]) -> None: + bases[:] = [b for b in bases if b is not object if get_origin(b) is not Generic if not b.__name__.startswith("_")] + + +def setup(app: Sphinx) -> ExtensionMetadata: + app.connect("autodoc-process-bases", skip_private_bases) + return ExtensionMetadata(parallel_read_safe=True) diff --git a/pyproject.toml b/pyproject.toml index 20c0987..ede5a0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,7 +124,7 @@ lint.ignore = [ "TID252", # allow relative imports ] lint.per-file-ignores."*/__init__.py" = [ "F401" ] -lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."docs/*" = [ "D", "I" ] lint.per-file-ignores."docs/notebooks/*" = [ "D", "F403", "F405" ] lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 4c26293..a62ad62 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Hashable from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -70,7 +71,7 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class ModMetaAcc[R: AdRef[str | None]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): +class ModMetaAcc[R: AdRef[str]](_ModalityMapAcc[str, pd.api.extensions.ExtensionArray | XVariable], MetaAcc[R]): """Reference accessor for arrays from metadata containers (:attr:`~anndata.acc.AdAcc.obs` / :attr:`~anndata.acc.AdAcc.var`).""" def __repr__(self) -> str: @@ -193,8 +194,8 @@ def __repr__(self) -> str: return f"A.mod[{self.mod}]" -@dataclass(frozen=True) -class MultiModAcc[R: AdRef](MapAcc[ModAcc[R]]): +@dataclass(frozen=True, kw_only=True) +class MultiModAcc[R: AdRef](MapAcc[ModAcc]): """Accessor for modalities (:attr:`~MuAcc.mod`).""" ref_class: type[R] @@ -235,8 +236,6 @@ def __post_init__(self) -> None: del self.__dict__["X"] del self.__dict__["layers"] - del self.__dataclass_fields__["X"] - del self.__dataclass_fields__["layers"] def __getitem__(self, k: str, /) -> ModAcc[R]: return self.mod[k] @@ -273,4 +272,12 @@ def resolve(self, spec: str, *, strict: bool = True) -> R | None: return super().resolve(spec, strict=strict) +del MuAcc.__dataclass_fields__["X"] +del MuAcc.__dataclass_fields__["layers"] +del MuAcc.__dataclass_fields__["layer_cls"] + A: MuAcc[AdRef] = MuAcc() + + +if not TYPE_CHECKING: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 + R = AdRef[Hashable] From 35b2138ef081950fcd52a1bb8155516155fe216f Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 12:01:44 +0200 Subject: [PATCH 08/19] implement to/from_json --- src/mudata/acc/__init__.py | 27 ++++++++++++++++++++++++++- tests/test_accessors.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index a62ad62..8992061 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Hashable +from collections.abc import Hashable, Sequence from dataclasses import KW_ONLY, dataclass, field from typing import TYPE_CHECKING, Any, ClassVar, Literal @@ -193,6 +193,10 @@ def __post_init__(self) -> None: def __repr__(self) -> str: return f"A.mod[{self.mod}]" + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + return ["mod", self.mod, super().to_json(ref)] + @dataclass(frozen=True, kw_only=True) class MultiModAcc[R: AdRef](MapAcc[ModAcc]): @@ -271,6 +275,27 @@ def resolve(self, spec: str, *, strict: bool = True) -> R | None: case _: return super().resolve(spec, strict=strict) + def to_json(self, ref: R) -> list[str | int | None]: + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + if isinstance(ref.acc, ModMapAcc): + return [f"{ref.acc.dim}map", ref.idx] + + ret = super().to_json(ref) + if isinstance(ref.acc, _ModalityMixin): + ret = ["mod", ref.acc.mod, ret] + return ret + + def from_json(self, data: Sequence[str | int | None]) -> R: + """Create a :class:`~anndata.acc.AdRef` from a JSON sequence.""" + match data: + case ["mod", str() as modname, list() as inner]: + return self.mod[modname].from_json(inner) + case ["obsmap" | "varmap" as dim, str() as modname]: + acc = self.obsmap if dim == "obsmap" else self.varmap + return acc[modname] + case _: + return super().from_json(data) + del MuAcc.__dataclass_fields__["X"] del MuAcc.__dataclass_fields__["layers"] diff --git a/tests/test_accessors.py b/tests/test_accessors.py index d6e2405..2b8df83 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -114,3 +114,32 @@ def test_resolve(): with pytest.raises(ValueError, match="period-separated"): A.resolve("abcd") + + +@pytest.mark.parametrize("acc", [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) +def test_to_from_json(acc): + serialized = A.to_json(acc) + if isinstance(acc.acc, md.acc.ModMapAcc): + assert serialized[0] == f"{acc.acc.dim}map" + assert serialized[1] == acc.idx + else: + assert serialized[0] == "mod" + assert serialized[1] == acc.acc.mod + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized) == acc + + +@pytest.mark.parametrize( + "acc", + [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef) and not isinstance(path[0].acc, md.acc.ModMapAcc)], +) +def test_to_from_json_mod(acc): + modA = A.mod["foobar"] + + serialized = modA.to_json(acc) + assert serialized[0] == "mod" + assert serialized[1] == "foobar" + assert serialized[2] == ad.acc.A.to_json(acc) + + assert A.from_json(serialized).acc.mod == "foobar" From eba2cdb4057de8deb2edf63fc47e7506c92e23bd Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:38:58 +0200 Subject: [PATCH 09/19] add changelog entry and require anndata 0.13 for docs --- CHANGELOG.md | 5 +++++ docs/conf.py | 2 +- pyproject.toml | 1 + 3 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4057365..a259394 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ and this project adheres to [Semantic Versioning][]. ## [0.4.0] (Unreleased) +### Added + +- MuData accessors. These are similar to and build on [AnnData accessors](https://anndata.scverse.org/page/accessors.html), but add an additional + level for modalities. + ### Changed - `update()` no longer automatically pulls obs/var columns from individual modalities by default. Set `mudata.set_options(pull_on_update=true)` diff --git a/docs/conf.py b/docs/conf.py index 493c8d5..b2012db 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -97,7 +97,7 @@ intersphinx_mapping = { "python": ("https://docs.python.org/3", None), - "anndata": ("https://anndata.readthedocs.io/en/latest/", None), + "anndata": ("https://anndata.readthedocs.io/en/stable/", None), "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), "numpy": ("https://numpy.org/doc/stable/", None), "pandas": ("https://pandas.pydata.org/docs/", None), diff --git a/pyproject.toml b/pyproject.toml index ede5a0c..126133d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev = [ ] test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] doc = [ + "anndata>=0.13rc1", "docutils>=0.8,!=0.18.*,!=0.19.*", "ipykernel", "ipython", From 7dc6577517379b80a2d701adea99a14909e8b3eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Jun 2026 12:46:30 +0000 Subject: [PATCH 10/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/mudata/_core/mudata.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mudata/_core/mudata.py b/src/mudata/_core/mudata.py index e661653..f34b9e8 100644 --- a/src/mudata/_core/mudata.py +++ b/src/mudata/_core/mudata.py @@ -555,6 +555,7 @@ def __contains__(self, key) -> bool: return key in self._mod with suppress(ImportError): from anndata.acc import AdRef, MapAcc, RefAcc + from ..acc import ModAcc, MultiModAcc, _ModalityMapAcc, _ModalityMixin if isinstance(key, ModAcc | _ModalityMapAcc): From c18f8b4c2c7c9740c89a563628a31601fe2e6405 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:49:37 +0200 Subject: [PATCH 11/19] docs: don't exclude X and layers from ModAcc --- docs/_templates/autosummary/class-accessor.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/_templates/autosummary/class-accessor.rst b/docs/_templates/autosummary/class-accessor.rst index cf3deb0..64c380a 100644 --- a/docs/_templates/autosummary/class-accessor.rst +++ b/docs/_templates/autosummary/class-accessor.rst @@ -1,6 +1,8 @@ {{ fullname | escape | underline}} +{%if fullname == "mudata.acc.MuAcc" %} {% set attributes = attributes | select("ne", "X") | select("ne", "layers") %} +{% endif %} .. currentmodule:: {{ module }} From f43b7bb818bbfaf1fa47e622465dd23d710ed548 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 14:54:44 +0200 Subject: [PATCH 12/19] fix test --- tests/test_accessors.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 2b8df83..0cdb854 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -7,11 +7,12 @@ from packaging.version import Version import mudata as md -from mudata.acc import A if Version(ad.__version__) < Version("0.13dev0"): pytest.skip("anndata version too old, no accessor support", allow_module_level=True) +from mudata.acc import A + @pytest.fixture def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): From 250b05cc61bac117d66067e79abb4927f352f582 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 Jun 2026 16:47:02 +0200 Subject: [PATCH 13/19] export mudata.acc when available --- src/mudata/__init__.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mudata/__init__.py b/src/mudata/__init__.py index bf7695d..6f1b3d8 100644 --- a/src/mudata/__init__.py +++ b/src/mudata/__init__.py @@ -25,9 +25,6 @@ from ._core.to_ import to_anndata, to_mudata from ._version import __version__, __version_tuple__ -with suppress(ImportError): - from . import acc - # file format versions __anndataversion__ = "0.1.0" __mudataversion__ = "0.1.0" @@ -56,3 +53,8 @@ "register_mudata_namespace", "ExtensionNamespace", ] + +with suppress(ImportError): + from . import acc + + __all__.append("acc") From 1f8b73b9151740ca8a7bdce460b4899338f0e900 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 17 Jun 2026 16:51:10 +0200 Subject: [PATCH 14/19] add test --- tests/test_accessors.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 0cdb854..0c31f9c 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -17,6 +17,7 @@ @pytest.fixture def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): mdata["mod1"].layers["counts"] = rng.poisson(1, size=mdata["mod1"].shape) + mdata["mod2"].varm["test"] = rng.normal(size=(mdata["mod2"].n_vars, 3)) mdata["mod2"].obsp["test"] = rng.normal(size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)) return mdata @@ -42,6 +43,8 @@ def test_anndata_accessors(mdata: md.MuData): (A["mod1"].layers, lambda md: md["mod1"].layers), (A["mod1"].layers["counts"], lambda md: md["mod1"].layers["counts"]), (A["mod1"].layers["counts"]["obs_2", :], lambda md: md["mod1"]["obs_2", :].layers["counts"].squeeze()), + (A["mod2"].varm, lambda md: md["mod2"].varm), + (A["mod2"].varm["test"], lambda md: md["mod2"].varm["test"]), (A["mod2"].obsp, lambda md: md["mod2"].obsp), (A["mod2"].obsp["test"], lambda md: md["mod2"].obsp["test"]), (A["mod2"].obsp["test"][:, "obs_3"], lambda md: md["mod2"].obsp["test"][:, md["mod2"].obs_names.get_loc("obs_3")]), From fae8a29ceff9859e8e136368b9fcd7c31a46ba48 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Thu, 18 Jun 2026 15:46:52 +0200 Subject: [PATCH 15/19] update for 0.13rc2 --- src/mudata/acc/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index 8992061..a542e8d 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -23,11 +23,11 @@ from anndata.compat import XVariable from anndata.typing import InMemoryArray +from .. import MuData + if TYPE_CHECKING: from anndata import AnnData - from .. import MuData - @dataclass(frozen=True, kw_only=True) class _ModalityMixin: @@ -125,7 +125,7 @@ def __repr__(self) -> str: @dataclass(frozen=True) -class ModMapAcc[R: AdRef[str]](RefAcc[R, str]): +class ModMapAcc[R: AdRef[str]](RefAcc[R, str, MuData]): """Reference accessor for modality maps (:attr:`~MuAcc.obsmap` / :attr:`~MuAcc.varmap`).""" dim: Literal["obs", "var"] @@ -305,4 +305,4 @@ def from_json(self, data: Sequence[str | int | None]) -> R: if not TYPE_CHECKING: # https://github.com/tox-dev/sphinx-autodoc-typehints/issues/580 - R = AdRef[Hashable] + R = AdRef[Hashable, MuData] From 701bbe335bdf6d75c7791bb455ef3ea5da0eff2d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jun 2026 18:16:02 +0000 Subject: [PATCH 16/19] [pre-commit.ci] pre-commit autoupdate (#162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit updates: - [github.com/tox-dev/pyproject-fmt: v2.24.1 → v2.25.0](https://github.com/tox-dev/pyproject-fmt/compare/v2.24.1...v2.25.0) - [github.com/astral-sh/ruff-pre-commit: v0.15.17 → v0.15.18](https://github.com/astral-sh/ruff-pre-commit/compare/v0.15.17...v0.15.18) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .pre-commit-config.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 78cd043..eda29b6 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,11 +12,11 @@ repos: - id: biome-format exclude: ^\.cruft\.json$ # inconsistent indentation with cruft - file never to be modified manually. - repo: https://github.com/tox-dev/pyproject-fmt - rev: v2.24.1 + rev: v2.25.0 hooks: - id: pyproject-fmt - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.17 + rev: v0.15.18 hooks: - id: ruff-check types_or: [python, pyi, jupyter] From 24461e27619263e8fad1cf3df61c7e18f8e1a994 Mon Sep 17 00:00:00 2001 From: ilia-kats Date: Wed, 24 Jun 2026 09:19:27 +0200 Subject: [PATCH 17/19] Update tests/test_accessors.py Co-authored-by: Ilan Gold --- tests/test_accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 0c31f9c..3f3867b 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -25,7 +25,7 @@ def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): def test_anndata_accessors(mdata: md.MuData): assert ad.acc.A.obs["arange"] in mdata - assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() + assert mdata[ad.acc.A.obs["arange"]] is mdata.obs["arange"] with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] with pytest.raises(KeyError, match="there is one in"): From 712d4c060da04104762c4d7bb5e96cb6fca91104 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 24 Jun 2026 09:25:02 +0200 Subject: [PATCH 18/19] Revert "Update tests/test_accessors.py" This reverts commit 36b3aeae18fae70716238182012f35c6c2218995. --- tests/test_accessors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 3f3867b..0c31f9c 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -25,7 +25,7 @@ def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): def test_anndata_accessors(mdata: md.MuData): assert ad.acc.A.obs["arange"] in mdata - assert mdata[ad.acc.A.obs["arange"]] is mdata.obs["arange"] + assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() with pytest.raises(KeyError, match="test"): mdata[ad.acc.A.var["test"]] with pytest.raises(KeyError, match="there is one in"): From 2432894d31bc4f0d70cd400900cc66ea0710f0f5 Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Wed, 24 Jun 2026 11:05:40 +0200 Subject: [PATCH 19/19] add JSON schema --- docs/conf.py | 2 ++ pyproject.toml | 2 +- src/mudata/acc/__init__.py | 10 ++++++++-- src/mudata/acc/acc-schema-v1.json | 31 +++++++++++++++++++++++++++++++ tests/test_accessors.py | 27 +++++++++++++++++++++++++-- 5 files changed, 67 insertions(+), 5 deletions(-) create mode 100644 src/mudata/acc/acc-schema-v1.json diff --git a/docs/conf.py b/docs/conf.py index b2012db..4245940 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,6 +15,7 @@ HERE = Path(__file__).parent sys.path.insert(0, str(HERE / "extensions")) +acc_schema = HERE.parent / "src/mudata/acc/acc-schema-v1.json" # -- Project information ----------------------------------------------------- @@ -120,6 +121,7 @@ # html_theme = "sphinx_book_theme" html_static_path = ["_static"] +html_extra_path = [str(acc_schema)] html_logo = "_static/img/mudata.svg" html_css_files = ["css/custom.css"] diff --git a/pyproject.toml b/pyproject.toml index 126133d..ab8d499 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,7 @@ dev = [ "pre-commit", "twine>=4.0.2", ] -test = [ "coverage>=7.10", "mudata[io]", "packaging", "pytest" ] +test = [ "coverage>=7.10", "jsonschema", "mudata[io]", "packaging", "pytest", "referencing" ] doc = [ "anndata>=0.13rc1", "docutils>=0.8,!=0.18.*,!=0.19.*", diff --git a/src/mudata/acc/__init__.py b/src/mudata/acc/__init__.py index a542e8d..b0d815c 100644 --- a/src/mudata/acc/__init__.py +++ b/src/mudata/acc/__init__.py @@ -276,7 +276,10 @@ def resolve(self, spec: str, *, strict: bool = True) -> R | None: return super().resolve(spec, strict=strict) def to_json(self, ref: R) -> list[str | int | None]: - """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list.""" + """Serialize :class:`~anndata.acc.AdRef` to a JSON-compatible list. + + Schema: `acc-schema-v1.json <../acc-schema-v1.json>`_ + """ if isinstance(ref.acc, ModMapAcc): return [f"{ref.acc.dim}map", ref.idx] @@ -286,7 +289,10 @@ def to_json(self, ref: R) -> list[str | int | None]: return ret def from_json(self, data: Sequence[str | int | None]) -> R: - """Create a :class:`~anndata.acc.AdRef` from a JSON sequence.""" + """Create a :class:`~anndata.acc.AdRef` from a JSON sequence. + + Schema: `acc-schema-v1.json <../acc-schema-v1.json>`_ + """ match data: case ["mod", str() as modname, list() as inner]: return self.mod[modname].from_json(inner) diff --git a/src/mudata/acc/acc-schema-v1.json b/src/mudata/acc/acc-schema-v1.json new file mode 100644 index 0000000..e0db520 --- /dev/null +++ b/src/mudata/acc/acc-schema-v1.json @@ -0,0 +1,31 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "array", + "oneOf": [ + { + "title": "Modality index mapping", + "minItems": 2, + "maxItems": 2, + "prefixItems": [ + { "enum": ["obsmap", "varmap"] }, + { "type": "string" } + ] + }, + { + "title": "Modality data", + "minItems": 3, + "maxItems": 3, + "prefixItems": [ + { "const": "mod" }, + { "type": "string" }, + { + "$ref": "https://anndata.scverse.org/en/latest/acc-schema-v1.json" + } + ] + }, + { + "title": "Global metadata", + "$ref": "https://anndata.scverse.org/en/latest/acc-schema-v1.json" + } + ] +} diff --git a/tests/test_accessors.py b/tests/test_accessors.py index 0c31f9c..a70d740 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -1,9 +1,14 @@ +import json from dataclasses import fields +from importlib import resources +from urllib.request import urlopen import anndata as ad +import jsonschema import numpy as np import pandas as pd import pytest +import referencing from packaging.version import Version import mudata as md @@ -23,6 +28,22 @@ def mdata_augmented(mdata: md.MuData, rng: np.random.Generator): return mdata +@pytest.fixture(scope="session") +def mudata_json_schema(): + schema = json.loads(resources.files(md).joinpath("acc/acc-schema-v1.json").read_text()) + jsonschema.Draft202012Validator.check_schema(schema) + return schema + + +@pytest.fixture(scope="session") +def anndata_schema_registry(): + anndata_schema_uri = "https://anndata.scverse.org/en/latest/acc-schema-v1.json" + with urlopen(anndata_schema_uri) as response: + anndata_schema = json.load(response) + schema = referencing.Resource.from_contents(anndata_schema) + return referencing.Registry().with_resource(anndata_schema_uri, schema) + + def test_anndata_accessors(mdata: md.MuData): assert ad.acc.A.obs["arange"] in mdata assert (mdata[ad.acc.A.obs["arange"]] == mdata.obs["arange"]).all() @@ -121,8 +142,9 @@ def test_resolve(): @pytest.mark.parametrize("acc", [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef)]) -def test_to_from_json(acc): +def test_to_from_json(mudata_json_schema, anndata_schema_registry, acc): serialized = A.to_json(acc) + jsonschema.validate(serialized, mudata_json_schema, registry=anndata_schema_registry) if isinstance(acc.acc, md.acc.ModMapAcc): assert serialized[0] == f"{acc.acc.dim}map" assert serialized[1] == acc.idx @@ -138,10 +160,11 @@ def test_to_from_json(acc): "acc", [path[0] for path in PATHS if isinstance(path[0], ad.acc.AdRef) and not isinstance(path[0].acc, md.acc.ModMapAcc)], ) -def test_to_from_json_mod(acc): +def test_to_from_json_mod(mudata_json_schema, anndata_schema_registry, acc): modA = A.mod["foobar"] serialized = modA.to_json(acc) + jsonschema.validate(serialized, mudata_json_schema, registry=anndata_schema_registry) assert serialized[0] == "mod" assert serialized[1] == "foobar" assert serialized[2] == ad.acc.A.to_json(acc)