diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 3ac1c33d7..38b0863be 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -98,10 +98,19 @@ def _validate_value(self, val: Value, key: str) -> Value: name = f"{self.attrname.title().rstrip('s')} {key!r}" return coerce_array(val, name=name, allow_df=self._allow_df) + _attrname_override: str | None = None + @property - @abstractmethod def attrname(self) -> str: """What attr for the AnnData is this?""" + if self._attrname_override is not None: + return self._attrname_override + return self._default_attrname + + @property + @abstractmethod + def _default_attrname(self) -> str: + """Default attr name derived from axis (e.g., 'obsm', 'varp').""" @property @abstractmethod @@ -151,6 +160,9 @@ def __init__(self, parent_mapping: P, parent_view: AnnData, subset_idx: I) -> No self.parent_mapping = parent_mapping self._parent = parent_view self.subset_idx = subset_idx + # Propagate attrname override from actual to view (for registered sections) + if parent_mapping._attrname_override is not None: + self._attrname_override = parent_mapping._attrname_override if hasattr(parent_mapping, "_axis"): # LayersBase has no _axis, the rest does self._axis = parent_mapping._axis # type: ignore @@ -237,7 +249,7 @@ class AxisArraysBase(AlignedMappingBase): _axis: Literal[0, 1] @property - def attrname(self) -> str: + def _default_attrname(self) -> str: return f"{self.dim}m" @property @@ -311,9 +323,12 @@ class LayersBase(AlignedMappingBase): """ _allow_df: ClassVar = False - attrname: ClassVar[Literal["layers"]] = "layers" axes: ClassVar[tuple[Literal[0], Literal[1]]] = (0, 1) + @property + def _default_attrname(self) -> str: + return "layers" + class Layers(AlignedActual, LayersBase): pass @@ -339,7 +354,7 @@ class PairwiseArraysBase(AlignedMappingBase): _axis: Literal[0, 1] @property - def attrname(self) -> str: + def _default_attrname(self) -> str: return f"{self.dim}p" @property @@ -402,8 +417,13 @@ class AlignedMappingProperty[T: AlignedMapping](property): def construct(self, obj: AnnData, *, store: MutableMapping[str, Value]) -> T: if self.axis is None: - return self.cls(obj, store=store) - return self.cls(obj, axis=self.axis, store=store) + mapping = self.cls(obj, store=store) + else: + mapping = self.cls(obj, axis=self.axis, store=store) + # Override attrname for registered sections (e.g., "obst" instead of "obsm") + if mapping._default_attrname != self.name: + mapping._attrname_override = self.name + return mapping @property def fget(self) -> Callable[[], None]: @@ -420,7 +440,11 @@ def __get__(self, obj: None | AnnData, objtype: type | None = None) -> T: # this needs to return a `property` instance, e.g. for Sphinx return self # type: ignore if not obj.is_view: - return self.construct(obj, store=getattr(obj, f"_{self.name}")) + store = getattr(obj, f"_{self.name}", None) + if store is None: + store = {} + setattr(obj, f"_{self.name}", store) + return self.construct(obj, store=store) parent_anndata = obj._adata_ref idxs = (obj._oidx, obj._vidx) parent: AlignedMapping = getattr(parent_anndata, self.name) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index b4e7fb3c2..ab65b8cd3 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -205,6 +205,7 @@ class AnnData(metaclass=utils.DeprecationMixinMeta): # noqa: PLW1641 ) _accessors: ClassVar[set[str]] = set() + _registered_sections: ClassVar[dict] = {} # str -> SectionSpec # view attributes _adata_ref: AnnData | None @@ -242,6 +243,7 @@ def __init__( # noqa: PLR0913 varp: np.ndarray | Mapping[str, Sequence[Any]] | None = None, oidx: _Index1DNorm | int | np.integer | None = None, vidx: _Index1DNorm | int | np.integer | None = None, + **extra_sections, ): # check for any multi-indices that aren’t later checked in coerce_array for attr, key in [(obs, "obs"), (var, "var"), (X, "X")]: @@ -270,6 +272,7 @@ def __init__( # noqa: PLR0913 varp=varp, filename=filename, filemode=filemode, + **extra_sections, ) def _init_as_view( @@ -361,6 +364,7 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 shape=None, filename=None, filemode=None, + **extra_sections, ): # view attributes self._is_view = False @@ -391,6 +395,15 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 if any((obs, var, uns, obsm, varm, obsp, varp)): msg = "If `X` is a dict no further arguments must be provided." raise ValueError(msg) + # Copy extension sections from source AnnData + # (built-in sections are handled by the explicit unpacking below) + for sec_name, spec in self._registered_sections.items(): + if spec.builtin: + continue + if sec_name not in extra_sections: + src_mapping = getattr(X, sec_name, None) + if src_mapping is not None and len(src_mapping) > 0: + extra_sections[sec_name] = dict(src_mapping) X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = ( X._X, X.obs, @@ -509,6 +522,12 @@ def _init_as_actual( # noqa: PLR0912, PLR0913, PLR0915 # layers self.layers = layers + # registered sections (e.g., obst, vart from extensions) + for sec_name in self._registered_sections: + value = extra_sections.get(sec_name) + if value is not None: + setattr(self, sec_name, value) + @old_positionals("show_stratified", "with_disk") def __sizeof__( self, *, show_stratified: bool = False, with_disk: bool = False @@ -545,21 +564,17 @@ def cs_to_bytes(X) -> int: return sum(sizes.values()) def _gen_repr(self, n_obs, n_vars) -> str: + from .section_registry import iter_sections + backed_at = f" backed at {str(self.filename)!r}" if self.isbacked else "" descr = f"AnnData object with n_obs × n_vars = {n_obs} × {n_vars}{backed_at}" - for attr in [ - "obs", - "var", - "uns", - "obsm", - "varm", - "layers", - "obsp", - "varp", - ]: - keys = getattr(self, attr).keys() + for spec, value in iter_sections(self, exclude_kinds={"X", "raw"}): + try: + keys = value.keys() + except Exception: # noqa: BLE001 + continue if len(keys) > 0: - descr += f"\n {attr}: {str(list(keys))[1:-1]}" + descr += f"\n {spec.name}: {str(list(keys))[1:-1]}" return descr def __repr__(self) -> str: @@ -1413,11 +1428,13 @@ def _mutated_copy(self, **kwargs) -> AnnData: raise NotImplementedError(msg) new = {} - for key in ["obs", "var", "obsm", "varm", "obsp", "varp", "layers"]: - if key in kwargs: - new[key] = kwargs[key] + from .section_registry import iter_sections + + for spec, value in iter_sections(self, kinds={"dataframe", "mapping"}): + if spec.name in kwargs: + new[spec.name] = kwargs[spec.name] else: - new[key] = getattr(self, key).copy() + new[spec.name] = value.copy() if "X" in kwargs: new["X"] = kwargs["X"] elif self._has_X(): @@ -2154,6 +2171,13 @@ def _remove_unused_categories_xr( pass # this is handled automatically by the categorical arrays themselves i.e., they dedup upon access. +# Populate _registered_sections with built-in section specs. +# Must happen after AnnData class definition is complete. +from .section_registry import _init_builtin_sections # noqa: E402 + +_init_builtin_sections(AnnData) + + def _check_2d_shape(X): """\ Check shape of array or sparse matrix. diff --git a/src/anndata/_core/extensions.py b/src/anndata/_core/extensions.py index 180a15a61..4fabf0f49 100644 --- a/src/anndata/_core/extensions.py +++ b/src/anndata/_core/extensions.py @@ -9,12 +9,18 @@ if TYPE_CHECKING: from collections.abc import Callable + from typing import Literal + + from anndata._repr.registry import FormattedEntry, FormatterContext # Based off of the extension framework in Polars # https://github.com/pola-rs/polars/blob/main/py-polars/polars/api.py -__all__ = ["register_anndata_namespace"] +__all__ = ["SectionSpec", "register_anndata_namespace", "register_section"] + +# Protocol for accessors that provide section visualization +REPR_SECTION_METHOD = "_repr_section_" # Reserved namespaces include accessors built into AnnData (currently there are none) @@ -121,6 +127,81 @@ def _check_namespace_signature(ns_class: type) -> None: raise TypeError(msg) +def _create_accessor_section_formatter( + name: str, ns_class: type[ExtensionNamespace] +) -> None: + """Create and register a SectionFormatter for an accessor with _repr_section_ method. + + This enables unified accessor + visualization registration. When an accessor + class defines a `_repr_section_` method, a SectionFormatter is automatically + registered that delegates to the accessor instance. + + Parameters + ---------- + name + The accessor name (used as section name) + ns_class + The accessor class that has a _repr_section_ method + """ + from anndata._repr.registry import ( + FormatterContext, + SectionFormatter, + register_formatter, + ) + + # Get optional section configuration from class attributes + after_section = getattr(ns_class, "section_after", None) + display_name = getattr(ns_class, "section_display_name", name) + tooltip = getattr(ns_class, "section_tooltip", "") + doc_url = getattr(ns_class, "section_doc_url", None) + + class AccessorSectionFormatter(SectionFormatter): + """Auto-generated SectionFormatter that delegates to accessor._repr_section_.""" + + @property + def section_name(self) -> str: + return name + + @property + def display_name(self) -> str: + return display_name + + @property + def after_section(self) -> str | None: + return after_section + + @property + def tooltip(self) -> str: + return tooltip + + @property + def doc_url(self) -> str | None: + return doc_url + + def should_show(self, obj: AnnData) -> bool: + if not hasattr(obj, name): + return False + accessor = getattr(obj, name) + if not hasattr(accessor, REPR_SECTION_METHOD): + return False + # Call _repr_section_ to check if it returns entries + result = getattr(accessor, REPR_SECTION_METHOD)(FormatterContext()) + return result is not None and len(result) > 0 + + def get_entries( + self, obj: AnnData, context: FormatterContext + ) -> list[FormattedEntry]: + accessor = getattr(obj, name) + result = getattr(accessor, REPR_SECTION_METHOD)(context) + return result if result is not None else [] + + # Give it a meaningful name for debugging + AccessorSectionFormatter.__name__ = f"{ns_class.__name__}SectionFormatter" + AccessorSectionFormatter.__qualname__ = f"{ns_class.__name__}SectionFormatter" + + register_formatter(AccessorSectionFormatter()) + + def _create_namespace[NameSpT: ExtensionNamespace]( name: str, cls: type[AnnData] ) -> Callable[[type[NameSpT]], type[NameSpT]]: @@ -138,6 +219,11 @@ def namespace(ns_class: type[NameSpT]) -> type[NameSpT]: ) setattr(cls, name, AccessorNameSpace(name, ns_class)) cls._accessors.add(name) + + # Auto-register SectionFormatter if accessor has _repr_section_ method + if hasattr(ns_class, REPR_SECTION_METHOD): + _create_accessor_section_formatter(name, ns_class) + return ns_class return namespace @@ -169,13 +255,32 @@ def register_anndata_namespace[NameSpT: ExtensionNamespace]( ----- Implementation requirements: - 1. The decorated class must have an `__init__` method that accepts exactly one parameter + 1. The decorated class must have an `__init__`` method that accepts exactly one parameter (besides `self`) named `adata` and annotated with type :class:`~anndata.AnnData`. 2. The namespace will be initialized with the AnnData object on first access and then cached on the instance. 3. If the namespace name conflicts with an existing namespace, a warning is issued. 4. If the namespace name conflicts with a built-in AnnData attribute, an AttributeError is raised. + HTML Representation + ~~~~~~~~~~~~~~~~~~~ + If the accessor class defines a ``_repr_section_`` method, a section will automatically + be added to the HTML representation. This enables unified accessor + visualization + registration with a single decorator. + + The ``_repr_section_`` method should have the signature:: + + def _repr_section_(self, context: FormatterContext) -> list[FormattedEntry] | None: + '''Return entries for HTML repr, or None to hide section.''' + ... + + Optional class attributes for section configuration: + + - ``section_after``: Section name after which this section appears (e.g., "obsm") + - ``section_display_name``: Display name for the section header (defaults to accessor name) + - ``section_tooltip``: Tooltip text for the section header + - ``section_doc_url``: URL to documentation (shown as link icon in header) + Examples -------- Simple transformation namespace with two methods: @@ -233,5 +338,253 @@ def register_anndata_namespace[NameSpT: ExtensionNamespace]( >>> adata.transform.arcsinh() # Transforms X and returns the AnnData object AnnData object with n_obs × n_vars = 100 × 2000 layers: 'log1p', 'arcsinh' + + Accessor with HTML section visualization: + + .. code-block:: python + + from anndata.extensions import ( + register_anndata_namespace, + FormattedEntry, + FormattedOutput, + ) + + + @register_anndata_namespace("spatial") + class SpatialAccessor: + # Optional: configure section positioning and display + section_after = "obsm" + section_display_name = "spatial" + section_tooltip = "Spatial data (images, coordinates)" + section_doc_url = "https://spatialdata.readthedocs.io/" + + def __init__(self, adata: ad.AnnData): + self._adata = adata + + @property + def images(self): + return self._adata.uns.get("spatial_images", {}) + + def add_image(self, key, image): + if "spatial_images" not in self._adata.uns: + self._adata.uns["spatial_images"] = {} + self._adata.uns["spatial_images"][key] = image + + def _repr_section_(self, context) -> list[FormattedEntry] | None: + '''Return entries for HTML repr, or None to hide section.''' + if not self.images: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput( + type_name=f"Image {v.shape}", + css_class="dtype-array", + ), + ) + for k, v in self.images.items() + ] + + + # Usage: + adata.spatial.add_image("hires", np.zeros((100, 100, 3))) + adata._repr_html_() # Shows "spatial" section with "hires" entry """ return _create_namespace(name, AnnData) + + +# --------------------------------------------------------------------------- +# Section registration +# --------------------------------------------------------------------------- + +from .section_registry import SectionProperty, SectionSpec # noqa: E402 + + +def register_section( + name: str, + *, + alignment: Literal["obs", "var"] | tuple[Literal["obs", "var"], ...] = (), + io_key: str | None = None, +) -> Callable[[type], type]: + """Register a new section on :class:`~anndata.AnnData`. + + Decorator that creates a section from a class definition. The class + can optionally define methods and attributes to customize behavior. + + Parameters + ---------- + name + Attribute name on AnnData (e.g., ``"obst"``). Becomes ``adata.obst``. + alignment + Axes each dimension is aligned to. A string for single-axis + alignment, or a tuple for multi-axis. Examples: + ``"obs"`` for obs-aligned (like obsm), + ``("obs", "var")`` for both axes (like layers), + ``("obs", "obs")`` for pairwise (like obsp), + ``()`` for unaligned. + io_key + Key used in h5ad/zarr files. Defaults to *name*. + + Class Attributes (all optional) + -------------------------------- + value_type : type + Type check on assignment (e.g., ``nx.DiGraph``). + section_after : str + Position in repr (e.g., ``"obsm"``). + section_tooltip : str + Hover text in HTML repr. + section_doc_url : str + Documentation link in HTML repr. + + Class Methods (all optional, must be static) + --------------------------------------------- + validate(key, value) + Custom validation on assignment. Raise on invalid. + subset(value, idx) + Custom subsetting for ``adata[idx]``. Default uses anndata's + ``_subset`` dispatch (works for arrays, sparse, DataFrames). + serialize(value) + Custom serialization for IO. Return a serializable object. + deserialize(data) + Custom deserialization for IO. + repr_entry(key, value, context) + Custom HTML repr formatting. Return ``FormattedOutput``. + + Examples + -------- + Simple axis-aligned section (arrays, no custom behavior): + + .. code-block:: python + + @register_section("obst", alignment="obs") + class ObstSection: + pass + + Full-featured section (TreeData-like): + + .. code-block:: python + + @register_section("obst", alignment="obs") + class ObstSection: + value_type = nx.DiGraph + section_after = "obsm" + section_tooltip = "Observation trees" + + @staticmethod + def validate(key, value): + if not nx.is_tree(value): + raise ValueError(f"{key} must be a tree") + + @staticmethod + def subset(value, idx): + return subset_tree(value, idx) + + @staticmethod + def serialize(value): + return digraph_to_json(value) + + @staticmethod + def deserialize(data): + return json_to_digraph(data) + + Unaligned section (SpatialData-like): + + .. code-block:: python + + @register_section("images", alignment=()) + class ImagesSection: + value_type = MultiscaleImage + """ + + # Normalize alignment: string → 1-tuple + if isinstance(alignment, str): + alignment = (alignment,) + + def decorator(cls: type) -> type: + if name in AnnData._registered_sections: + msg = f"Section {name!r} is already registered" + raise ValueError(msg) + if name in _reserved_namespaces: + msg = f"Cannot register section {name!r}: conflicts with existing AnnData attribute" + raise AttributeError(msg) + + # Extract optional methods and attributes from the class + spec = SectionSpec( + name=name, + alignment=alignment, + io_key=io_key or name, + value_type=getattr(cls, "value_type", None), + validate_fn=getattr(cls, "validate", None), + subset_fn=getattr(cls, "subset", None), + serialize_fn=getattr(cls, "serialize", None), + deserialize_fn=getattr(cls, "deserialize", None), + repr_entry_fn=getattr(cls, "repr_entry", None), + section_after=getattr(cls, "section_after", None), + section_tooltip=getattr(cls, "section_tooltip", ""), + section_doc_url=getattr(cls, "section_doc_url", None), + ) + + # Create and attach the property descriptor + prop = SectionProperty(spec) + setattr(AnnData, name, prop) + + # Register + AnnData._registered_sections[name] = spec + _reserved_namespaces.add(name) + + # Auto-register SectionFormatter for HTML repr if repr metadata is present + if spec.section_after or spec.repr_entry_fn: + _create_section_repr_formatter(spec) + + return cls + + return decorator + + +def _create_section_repr_formatter(spec: SectionSpec) -> None: + """Auto-register a SectionFormatter for a registered section.""" + from anndata._repr.registry import ( + FormattedEntry, + FormattedOutput, + SectionFormatter, + register_formatter, + ) + + class RegisteredSectionFormatter(SectionFormatter): + @property + def section_name(self) -> str: + return spec.name + + @property + def after_section(self) -> str | None: + return spec.section_after + + @property + def tooltip(self) -> str: + return spec.section_tooltip + + @property + def doc_url(self) -> str | None: + return spec.section_doc_url + + def should_show(self, obj: AnnData) -> bool: + mapping = getattr(obj, spec.name, None) + return mapping is not None and len(mapping) > 0 + + def get_entries( + self, obj: AnnData, context: FormatterContext + ) -> list[FormattedEntry]: + mapping = getattr(obj, spec.name) + entries = [] + for k in mapping: + if spec.repr_entry_fn is not None: + output = spec.repr_entry_fn(k, mapping[k], context) + else: + output = FormattedOutput( + type_name=type(mapping[k]).__name__, + ) + entries.append(FormattedEntry(key=k, output=output)) + return entries + + RegisteredSectionFormatter.__name__ = f"{spec.name}SectionFormatter" + register_formatter(RegisteredSectionFormatter()) diff --git a/src/anndata/_core/section_registry.py b/src/anndata/_core/section_registry.py new file mode 100644 index 000000000..833e2ab1e --- /dev/null +++ b/src/anndata/_core/section_registry.py @@ -0,0 +1,423 @@ +"""Pluggable section registry for AnnData. + +Provides the infrastructure for :func:`~anndata.extensions.register_section`: +container classes, view handling, and property descriptors that let external +packages add new sections to AnnData without subclassing. +""" + +from __future__ import annotations + +from collections.abc import Mapping, MutableMapping +from copy import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from .views import view_update + +if TYPE_CHECKING: + from collections.abc import Callable, Iterator + from typing import Any, Literal + +if TYPE_CHECKING: + from anndata import AnnData + + from .._repr.registry import FormattedOutput, FormatterContext + + +def _axis_len(value: Any, dim: int) -> int | None: + """Get length of value along a dimension, or None if not applicable.""" + if hasattr(value, "shape"): + shape = value.shape + if dim < len(shape): + return shape[dim] + return None + + +@dataclass(frozen=True) +class SectionSpec: + """Complete specification for a registered section. + + Created by :func:`register_section` from the decorated class, + or internally for built-in sections. + """ + + name: str + """Attribute name on AnnData (e.g., ``"obst"``).""" + alignment: tuple[Literal["obs", "var"], ...] + """Axes each dimension is aligned to. Empty tuple for unaligned.""" + io_key: str + """Key used in h5ad/zarr files.""" + kind: Literal["X", "dataframe", "mapping", "unstructured", "raw"] = "mapping" + """Section category. Used by :func:`iter_sections` for filtering.""" + builtin: bool = False + """Whether this is a built-in section (vs. registered by an extension).""" + + # Optional callbacks extracted from the section class + value_type: type | None = None + validate_fn: Callable[[str, Any], None] | None = None + subset_fn: Callable[[Any, Any], Any] | None = None + serialize_fn: Callable[[Any], Any] | None = None + deserialize_fn: Callable[[Any], Any] | None = None + repr_entry_fn: Callable[[str, Any, FormatterContext], FormattedOutput] | None = None + + # Repr metadata + section_after: str | None = None + section_tooltip: str = "" + section_doc_url: str | None = None + + +class SectionMapping(MutableMapping): + """Container for a registered section's data. + + Validates values on assignment using the section's spec (type check, + alignment validation, custom validator). + """ + + def __init__( + self, parent: AnnData, spec: SectionSpec, data: dict | None = None + ) -> None: + self._parent = parent + self._spec = spec + self._data: dict[str, Any] = data if data is not None else {} + + def __repr__(self) -> str: + return f"{self._spec.name}: {', '.join(map(repr, self._data.keys()))}" + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: + # Type check + if self._spec.value_type is not None and not isinstance( + value, self._spec.value_type + ): + msg = ( + f"Values in {self._spec.name!r} must be {self._spec.value_type.__name__}, " + f"got {type(value).__name__}" + ) + raise TypeError(msg) + # Alignment validation + self._validate_alignment(key, value) + # Custom validation + if self._spec.validate_fn is not None: + self._spec.validate_fn(key, value) + self._data[key] = value + + def __delitem__(self, key: str) -> None: + del self._data[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._data) + + def __len__(self) -> int: + return len(self._data) + + def __contains__(self, key: object) -> bool: + return key in self._data + + def _validate_alignment(self, key: str, value: Any) -> None: + """Check that value dimensions match the expected axes.""" + for i, axis in enumerate(self._spec.alignment): + expected = self._parent.n_obs if axis == "obs" else self._parent.n_vars + actual = _axis_len(value, i) + if actual is not None and actual != expected: + n_name = "n_obs" if axis == "obs" else "n_vars" + msg = ( + f"Value for {self._spec.name}[{key!r}] has shape[{i}]={actual}, " + f"expected {expected} ({n_name})" + ) + raise ValueError(msg) + + def copy(self) -> dict[str, Any]: + """Return a deep copy of the underlying data.""" + return { + k: copy(v) if not hasattr(v, "copy") else v.copy() + for k, v in self._data.items() + } + + +class SectionMappingView(Mapping): + """Read-only view of a registered section that subsets on access. + + Writing triggers copy-on-write via anndata's view_update mechanism. + """ + + def __init__( + self, + parent_mapping: SectionMapping, + parent_view: AnnData, + obs_idx: Any, + var_idx: Any, + ) -> None: + self._parent_mapping = parent_mapping + self._parent = parent_view + self._spec = parent_mapping._spec + self._obs_idx = obs_idx + self._var_idx = var_idx + + def __repr__(self) -> str: + return f"{self._spec.name} (view): {', '.join(map(repr, self._parent_mapping._data.keys()))}" + + def __getitem__(self, key: str) -> Any: + value = self._parent_mapping[key] + if not self._spec.alignment: + return value # unaligned, no subsetting + return self._subset_value(value) + + def __setitem__(self, key: str, value: Any) -> None: + from .._warnings import ImplicitModificationWarning + from ..utils import warn + + warn( + f"Setting element `.{self._spec.name}[{key!r}]` of view, " + "initializing view as actual.", + ImplicitModificationWarning, + ) + with view_update(self._parent, self._spec.name, ()) as new_mapping: + new_mapping[key] = value + + def __delitem__(self, key: str) -> None: + from .._warnings import ImplicitModificationWarning + from ..utils import warn + + if key not in self: + msg = f"{key!r} not found in view of {self._spec.name}" + raise KeyError(msg) + warn( + f"Removing element `.{self._spec.name}[{key!r}]` of view, " + "initializing view as actual.", + ImplicitModificationWarning, + ) + with view_update(self._parent, self._spec.name, ()) as new_mapping: + del new_mapping[key] + + def __iter__(self) -> Iterator[str]: + return iter(self._parent_mapping) + + def __len__(self) -> int: + return len(self._parent_mapping) + + def __contains__(self, key: object) -> bool: + return key in self._parent_mapping + + def _subset_value(self, value: Any) -> Any: + """Subset a value according to the alignment tuple.""" + idx = self._build_index() + if self._spec.subset_fn is not None: + return self._spec.subset_fn(value, idx) + # Default subsetting: handle N-dimensional alignment + # anndata's _subset is designed for ≤2D, so for higher dims + # we do the indexing directly. + import numpy as np + + from anndata.compat import IndexManager + + if isinstance(idx, tuple) and len(idx) > 2: + # Convert IndexManagers to numpy arrays + resolved = [] + for ix in idx: + if isinstance(ix, IndexManager): + resolved.append(np.asarray(ix)) + else: + resolved.append(ix) + # Use np.ix_ for fancy indexing on non-slice dims + fancy_dims = [ + i for i, ix in enumerate(resolved) if not isinstance(ix, slice) + ] + if fancy_dims: + # Build an open mesh for fancy-indexed dims + fancy_arrs = [resolved[i] for i in fancy_dims] + mesh = np.ix_(*fancy_arrs) + # Build the full index tuple + full_idx = list(resolved) + for mi, di in enumerate(fancy_dims): + full_idx[di] = mesh[mi] + return value[tuple(full_idx)] + return value[tuple(resolved)] + # ≤2D: use anndata's _subset + from .index import _subset + + return _subset(value, idx) + + def _build_index(self) -> tuple: + """Build the index tuple from alignment and view indices.""" + indices = [] + for axis in self._spec.alignment: + if axis == "obs": + indices.append(self._obs_idx) + elif axis == "var": + indices.append(self._var_idx) + return tuple(indices) + + def copy(self) -> dict[str, Any]: + """Copy with subsetting applied.""" + return { + k: self[k].copy() if hasattr(self[k], "copy") else self[k] for k in self + } + + +class SectionProperty: + """Descriptor for registered sections on AnnData. + + Creates ephemeral SectionMapping / SectionMappingView on access, + similar to AlignedMappingProperty for built-in sections. + """ + + def __init__(self, spec: SectionSpec) -> None: + self.spec = spec + + def __get__(self, obj: AnnData | None, objtype: type | None = None) -> Any: + if obj is None: + return self + if not obj.is_view: + data = getattr(obj, f"_{self.spec.name}", None) + if data is None: + data = {} + setattr(obj, f"_{self.spec.name}", data) + return SectionMapping(obj, self.spec, data) + # View: create subsetting view + parent = obj._adata_ref + parent_mapping = getattr(parent, self.spec.name) + return SectionMappingView(parent_mapping, obj, obj._oidx, obj._vidx) + + def __set__(self, obj: AnnData, value: Mapping[str, Any] | None) -> None: + if value is None: + value = {} + if isinstance(value, (SectionMapping, SectionMappingView, Mapping)): + value = dict(value) + # Validate all values via SectionMapping + mapping = SectionMapping(obj, self.spec, {}) + for k, v in value.items(): + mapping[k] = v # validates each + if obj.is_view: + obj._init_as_actual(obj.copy()) + setattr(obj, f"_{self.spec.name}", mapping._data) + + def __delete__(self, obj: AnnData) -> None: + setattr(obj, f"_{self.spec.name}", {}) + + +# --------------------------------------------------------------------------- +# Built-in section specs (metadata only — the actual descriptors are +# AlignedMappingProperty instances already on the AnnData class) +# --------------------------------------------------------------------------- + +#: Ordered list of all built-in sections, used to seed ``_registered_sections``. +BUILTIN_SECTIONS: list[SectionSpec] = [ + SectionSpec(name="X", alignment=("obs", "var"), io_key="X", kind="X", builtin=True), + SectionSpec( + name="obs", alignment=("obs",), io_key="obs", kind="dataframe", builtin=True + ), + SectionSpec( + name="var", alignment=("var",), io_key="var", kind="dataframe", builtin=True + ), + SectionSpec( + name="uns", alignment=(), io_key="uns", kind="unstructured", builtin=True + ), + SectionSpec( + name="obsm", alignment=("obs",), io_key="obsm", kind="mapping", builtin=True + ), + SectionSpec( + name="varm", alignment=("var",), io_key="varm", kind="mapping", builtin=True + ), + SectionSpec( + name="layers", + alignment=("obs", "var"), + io_key="layers", + kind="mapping", + builtin=True, + ), + SectionSpec( + name="obsp", + alignment=("obs", "obs"), + io_key="obsp", + kind="mapping", + builtin=True, + ), + SectionSpec( + name="varp", + alignment=("var", "var"), + io_key="varp", + kind="mapping", + builtin=True, + ), + SectionSpec(name="raw", alignment=(), io_key="raw", kind="raw", builtin=True), +] + + +def _init_builtin_sections(cls: type[AnnData]) -> None: + """Populate ``_registered_sections`` with built-in section specs. + + Called once during AnnData class setup. Does NOT create descriptors — + the built-in ``AlignedMappingProperty`` instances are already on the class. + """ + for spec in BUILTIN_SECTIONS: + cls._registered_sections[spec.name] = spec + + +# --------------------------------------------------------------------------- +# Section iteration utility +# --------------------------------------------------------------------------- + + +def iter_sections( + adata: AnnData, + *, + kinds: set[str] | None = None, + exclude_kinds: set[str] | None = None, + only_nonempty: bool = False, +) -> Iterator[tuple[SectionSpec, Any]]: + """Iterate over AnnData sections with optional filtering. + + Yields ``(spec, value)`` pairs for each section, where *value* is + the result of ``getattr(adata, spec.name)``. + + Parameters + ---------- + adata + AnnData to iterate over. + kinds + If given, only yield sections whose ``kind`` is in this set. + E.g., ``{"mapping"}`` for dict-like sections (obsm, layers, …). + exclude_kinds + If given, skip sections whose ``kind`` is in this set. + E.g., ``{"unstructured", "raw"}`` to skip uns and raw. + only_nonempty + If ``True``, skip sections that are empty or ``None``. + + Examples + -------- + All mapping sections (built-in + registered): + + >>> for spec, mapping in iter_sections(adata, kinds={"mapping"}): + ... print(spec.name, list(mapping.keys())) + + Everything except uns and raw: + + >>> for spec, value in iter_sections(adata, exclude_kinds={"unstructured", "raw"}): + ... ... + + Non-empty sections for repr: + + >>> for spec, value in iter_sections(adata, only_nonempty=True): + ... print(spec.name) + """ + for spec in adata._registered_sections.values(): + if kinds is not None and spec.kind not in kinds: + continue + if exclude_kinds is not None and spec.kind in exclude_kinds: + continue + try: + value = getattr(adata, spec.name, None) + except Exception: # noqa: BLE001 + # Crashing objects in aligned mappings (adversarial data) + continue + if only_nonempty: + if value is None: + continue + try: + if len(value) == 0: + continue + except TypeError: + pass # no len, treat as non-empty + yield spec, value diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index d0540de1c..9ec0b93a0 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -93,14 +93,21 @@ def write_h5ad( dataset_kwargs=dataset_kwargs, ) _write_raw(f, adata.raw, as_dense=as_dense, dataset_kwargs=dataset_kwargs) - write_elem(f, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - write_elem(f, "var", adata.var, dataset_kwargs=dataset_kwargs) - write_elem(f, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) + + # Write all non-X/raw sections via the unified registry + from anndata._core.section_registry import iter_sections + + for spec, value in iter_sections(adata, exclude_kinds={"X", "raw"}): + # Skip empty mappings (but always write DataFrames — they carry the index) + if spec.kind != "dataframe" and len(value) == 0: + continue + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in value.items()} + elif spec.kind == "dataframe": + data = value # write DataFrame directly + else: + data = dict(value) # mappings and uns → dict + write_elem(f, spec.io_key, data, dataset_kwargs=dataset_kwargs) def _write_x( @@ -262,13 +269,22 @@ def read_h5ad( def callback(read_func, elem_name: str, elem: StorageType, iospec: IOSpec): if iospec.encoding_type == "anndata" or elem_name.endswith("/"): - return AnnData(**{ + d = { # This is covering up backwards compat in the anndata initializer # In most cases we should be able to call `func(elen[k])` instead k: read_dispatched(elem[k], callback) for k in elem if not k.startswith("raw.") - }) + } + # Deserialize registered sections + for spec in AnnData._registered_sections.values(): + if spec.io_key in d and spec.deserialize_fn is not None: + data = d[spec.io_key] + if isinstance(data, dict): + d[spec.io_key] = { + k: spec.deserialize_fn(v) for k, v in data.items() + } + return AnnData(**d) elif elem_name.startswith("/raw."): return None elif elem_name == "/X" and "X" in as_sparse: diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 43b084a00..408ed192f 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -285,18 +285,25 @@ def write_anndata( _writer: Writer, dataset_kwargs: Mapping[str, Any] = MappingProxyType({}), ): + from anndata._core.section_registry import iter_sections + g = f.require_group(k) + # X and raw need special handling if adata.X is not None: _writer.write_elem(g, "X", adata.X, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obs", adata.obs, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "var", adata.var, dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsm", dict(adata.obsm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) + # All other sections via the unified registry + for spec, value in iter_sections(adata, exclude_kinds={"X", "raw"}): + # Skip empty mappings (but always write DataFrames — they carry the index) + if spec.kind != "dataframe" and len(value) == 0: + continue + if spec.serialize_fn is not None: + data = {k: spec.serialize_fn(v) for k, v in value.items()} + elif spec.kind == "dataframe": + data = value # write DataFrame directly + else: + data = dict(value) # mappings and uns → dict + _writer.write_elem(g, spec.io_key, data, dataset_kwargs=dataset_kwargs) @_REGISTRY.register_read(H5Group, IOSpec("anndata", "0.1.0")) @@ -307,20 +314,12 @@ def write_anndata( @_REGISTRY.register_read(ZarrGroup, IOSpec("raw", "0.1.0")) def read_anndata(elem: _GroupStorageType | H5File, *, _reader: Reader) -> AnnData: d = {} - for k in [ - "X", - "obs", - "var", - "obsm", - "varm", - "obsp", - "varp", - "layers", - "uns", - "raw", - ]: - if k in elem: - d[k] = _reader.read_elem(elem[k]) + for spec in AnnData._registered_sections.values(): + if spec.io_key in elem: + data = _reader.read_elem(elem[spec.io_key]) + if spec.deserialize_fn is not None and isinstance(data, dict): + data = {k: spec.deserialize_fn(v) for k, v in data.items()} + d[spec.name] = data return AnnData(**d) diff --git a/src/anndata/_repr/__init__.py b/src/anndata/_repr/__init__.py index 1b61850be..28b34b46a 100644 --- a/src/anndata/_repr/__init__.py +++ b/src/anndata/_repr/__init__.py @@ -54,6 +54,18 @@ - Support for nested AnnData objects - Graceful handling of unknown types +.. note:: + + For extending AnnData with custom formatters, prefer importing from + :mod:`anndata.extensions` which provides the public API:: + + from anndata.extensions import ( + register_formatter, + TypeFormatter, + SectionFormatter, + FormattedOutput, + ) + Extensibility ------------- The system is designed to be extensible via two registry patterns: @@ -68,7 +80,11 @@ Example - format by Python type:: - from anndata._repr import register_formatter, TypeFormatter, FormattedOutput + from anndata.extensions import ( + register_formatter, + TypeFormatter, + FormattedOutput, + ) @register_formatter @@ -101,8 +117,12 @@ def format(self, obj, context): Example - format by embedded type hint (for tagged data in uns):: - from anndata._repr import register_formatter, TypeFormatter, FormattedOutput - from anndata._repr import extract_uns_type_hint + from anndata.extensions import ( + register_formatter, + TypeFormatter, + FormattedOutput, + extract_uns_type_hint, + ) @register_formatter @@ -150,8 +170,12 @@ def format(self, obj, context): Example:: - from anndata._repr import register_formatter, SectionFormatter - from anndata._repr import FormattedEntry, FormattedOutput + from anndata.extensions import ( + register_formatter, + SectionFormatter, + FormattedEntry, + FormattedOutput, + ) @register_formatter diff --git a/src/anndata/extensions.py b/src/anndata/extensions.py new file mode 100644 index 000000000..2bb623d2a --- /dev/null +++ b/src/anndata/extensions.py @@ -0,0 +1,122 @@ +""" +Public API for extending AnnData functionality. + +This module provides registration mechanisms for: + +1. **Accessors** - Add custom namespaces to AnnData objects (e.g., `adata.myns.method()`) +2. **Aligned Sections** - Add new axis-aligned mappings (e.g., `adata.obst`) with full + subsetting, IO, repr, and init support — no subclassing needed +3. **HTML Formatters** - Customize how types are displayed in Jupyter notebooks + +Examples +-------- +Register a custom accessor namespace:: + + import anndata as ad + from anndata.extensions import register_anndata_namespace + + + @register_anndata_namespace("transform") + class TransformAccessor: + def __init__(self, adata: ad.AnnData): + self._adata = adata + + def log1p(self): + import numpy as np + + self._adata.X = np.log1p(self._adata.X) + return self._adata + + + # Usage: adata.transform.log1p() + +Register a custom HTML formatter for a type:: + + from anndata.extensions import register_formatter, TypeFormatter, FormattedOutput + + + @register_formatter + class MyArrayFormatter(TypeFormatter): + priority = 100 # Higher = checked first + + def can_format(self, obj): + return isinstance(obj, MyArrayType) + + def format(self, obj, context): + return FormattedOutput( + type_name=f"MyArray {obj.shape}", + css_class="dtype-custom", + ) + +Register a custom section formatter (for packages like TreeData, SpatialData):: + + from anndata.extensions import register_formatter, SectionFormatter + from anndata.extensions import FormattedEntry, FormattedOutput + + + @register_formatter + class ObstSectionFormatter(SectionFormatter): + section_name = "obst" + after_section = "obsm" # Position in display order + + def should_show(self, obj): + return hasattr(obj, "obst") and len(obj.obst) > 0 + + def get_entries(self, obj, context): + return [ + FormattedEntry( + key=k, + output=FormattedOutput(type_name=f"Tree ({v.n_nodes} nodes)"), + ) + for k, v in obj.obst.items() + ] + +See Also +-------- +anndata._repr : Full documentation of the HTML representation system +""" + +from __future__ import annotations + +# Accessor registration (from PR #1870) +# Section registration (pluggable sections with custom alignment, IO, validation) +from anndata._core.extensions import register_anndata_namespace, register_section +from anndata._core.section_registry import SectionSpec + +# HTML representation formatters +from anndata._repr import ( + # Type hint utilities for tagged data + UNS_TYPE_HINT_KEY, + # Core formatter classes + FormattedEntry, + FormattedOutput, + FormatterContext, + FormatterRegistry, + SectionFormatter, + TypeFormatter, + extract_uns_type_hint, + # Global registry instance + formatter_registry, + # Registration function + register_formatter, +) + +__all__ = [ # noqa: RUF022 # organized by category, not alphabetically + # Accessor registration + "register_anndata_namespace", + # Section registration + "register_section", + "SectionSpec", + # HTML formatter registration + "register_formatter", + "TypeFormatter", + "SectionFormatter", + "FormattedOutput", + "FormattedEntry", + "FormatterContext", + "FormatterRegistry", + "formatter_registry", + # Type hint utilities + "extract_uns_type_hint", + "UNS_TYPE_HINT_KEY", +] diff --git a/tests/test_registered_sections.py b/tests/test_registered_sections.py new file mode 100644 index 000000000..40fc32263 --- /dev/null +++ b/tests/test_registered_sections.py @@ -0,0 +1,913 @@ +"""Tests for register_section decorator. + +Validates that registered sections behave correctly for all alignment +combinations, custom validation, custom subsetting, custom IO, and +HTML repr integration. Uses TreeData-like, SpatialData-like, and +xarray scenarios. +""" + +from __future__ import annotations + +import json + +import numpy as np +import pandas as pd +import pytest +import xarray as xr +from scipy.sparse import csr_matrix + +import anndata as ad +from anndata.extensions import register_section + +# --------------------------------------------------------------------------- +# Fixtures: register sections once per test session +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True, scope="module") +def _register_test_sections(): # noqa: PLR0912 + """Register test sections for all alignment combinations.""" + # obs-aligned (like obsm) + if "sec_obs" not in ad.AnnData._registered_sections: + + @register_section("sec_obs", alignment="obs") + class SecObs: + pass + + # var-aligned (like varm) + if "sec_var" not in ad.AnnData._registered_sections: + + @register_section("sec_var", alignment="var") + class SecVar: + pass + + # Both axes (like layers) + if "sec_both" not in ad.AnnData._registered_sections: + + @register_section("sec_both", alignment=("obs", "var")) + class SecBoth: + pass + + # Pairwise obs (like obsp) + if "sec_pair_obs" not in ad.AnnData._registered_sections: + + @register_section("sec_pair_obs", alignment=("obs", "obs")) + class SecPairObs: + pass + + # Pairwise var (like varp) + if "sec_pair_var" not in ad.AnnData._registered_sections: + + @register_section("sec_pair_var", alignment=("var", "var")) + class SecPairVar: + pass + + # Unaligned (like SpatialData images) + if "sec_unaligned" not in ad.AnnData._registered_sections: + + @register_section("sec_unaligned", alignment=()) + class SecUnaligned: + pass + + # Custom type validation (TreeData-like) + if "sec_typed" not in ad.AnnData._registered_sections: + + @register_section("sec_typed", alignment="obs") + class SecTyped: + value_type = np.ndarray + + @staticmethod + def validate(key, value): + if value.ndim != 2: + msg = f"{key} must be 2D" + raise ValueError(msg) + + # Custom serialize/deserialize + if "sec_custom_io" not in ad.AnnData._registered_sections: + + @register_section("sec_custom_io", alignment="obs") + class SecCustomIO: + @staticmethod + def serialize(value): + # Convert dict to JSON string for storage + return json.dumps(value) + + @staticmethod + def deserialize(data): + return json.loads(data) + + # Cell-cell communication tensor: (sender, receiver, gene) + if "cellcomm" not in ad.AnnData._registered_sections: + + @register_section("cellcomm", alignment=("obs", "obs", "var")) + class CellCommSection: + """Ligand-receptor communication scores (sender × receiver × gene).""" + + section_after = "obsp" + section_tooltip = "Cell-cell communication" + + # Cell-specific gene-gene interactions: (obs, var, var) + if "genereg" not in ad.AnnData._registered_sections: + + @register_section("genereg", alignment=("obs", "var", "var")) + class GeneRegSection: + """Cell-specific gene regulatory networks (cell × gene × gene).""" + + section_after = "varp" + section_tooltip = "Gene regulation per cell" + + # Custom subset + if "sec_custom_subset" not in ad.AnnData._registered_sections: + + @register_section("sec_custom_subset", alignment="obs") + class SecCustomSubset: + @staticmethod + def subset(value, idx): + # Custom: return a dict describing the subset + return {"original": value, "subset_idx": idx} + + # Factored tensor: store rank-R factors, reconstruct on demand + if "comm_obs" not in ad.AnnData._registered_sections: + + @register_section("comm_obs", alignment="obs") + class CommObs: + """Cell factor matrix (n_obs × rank) for communication tensor.""" + + if "comm_var" not in ad.AnnData._registered_sections: + + @register_section("comm_var", alignment="var") + class CommVar: + """Gene factor matrix (n_vars × rank) for communication tensor.""" + + # xarray layers (custom type with serialize/deserialize) + if "xr_layers" not in ad.AnnData._registered_sections: + + @register_section("xr_layers", alignment=("obs", "var")) + class XarrayLayers: + value_type = xr.DataArray + + @staticmethod + def serialize(value): + return value.values # xarray → numpy for h5ad + + @staticmethod + def deserialize(data): + return xr.DataArray(data) # numpy → xarray on read + + +@pytest.fixture +def adata(): + """Basic AnnData for testing.""" + return ad.AnnData( + X=np.ones((5, 3)), + obs=pd.DataFrame({"group": list("aabbc")}, index=[f"c{i}" for i in range(5)]), + var=pd.DataFrame(index=[f"v{i}" for i in range(3)]), + ) + + +# --------------------------------------------------------------------------- +# Registration API +# --------------------------------------------------------------------------- + + +class TestRegistrationAPI: + def test_register_creates_property(self): + assert hasattr(ad.AnnData, "sec_obs") + adata = ad.AnnData(np.ones((3, 4))) + assert len(adata.sec_obs) == 0 + + def test_register_duplicate_raises(self): + with pytest.raises(ValueError, match="already registered"): + register_section("sec_obs", alignment="obs")(type("Dup", (), {})) + + def test_register_reserved_name_raises(self): + # "obs" is a built-in registered section, so it's already registered + with pytest.raises(ValueError, match="already registered"): + register_section("obs", alignment="obs")(type("Bad", (), {})) + + def test_all_sections_in_registry(self): + for name in [ + "sec_obs", + "sec_var", + "sec_both", + "sec_pair_obs", + "sec_pair_var", + "sec_unaligned", + "sec_typed", + ]: + assert name in ad.AnnData._registered_sections + + +# --------------------------------------------------------------------------- +# Obs-aligned: alignment=("obs",) +# --------------------------------------------------------------------------- + + +class TestObsAligned: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(5, 3) + adata.sec_obs["x"] = arr + np.testing.assert_array_equal(adata.sec_obs["x"], arr) + + def test_wrong_shape_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_obs["bad"] = np.ones((10, 3)) + + def test_sparse(self, adata): + adata.sec_obs["sp"] = csr_matrix(np.eye(5)) + assert adata.sec_obs["sp"].shape == (5, 5) + + def test_subset_obs(self, adata): + adata.sec_obs["x"] = np.arange(15).reshape(5, 3) + sub = adata[:3] + assert sub.sec_obs["x"].shape == (3, 3) + + def test_subset_var_unchanged(self, adata): + adata.sec_obs["x"] = np.arange(15).reshape(5, 3) + sub = adata[:, :2] + # obs-aligned section not affected by var subsetting + assert sub.sec_obs["x"].shape == (5, 3) + + +# --------------------------------------------------------------------------- +# Var-aligned: alignment=("var",) +# --------------------------------------------------------------------------- + + +class TestVarAligned: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(3, 2) + adata.sec_var["x"] = arr + np.testing.assert_array_equal(adata.sec_var["x"], arr) + + def test_wrong_shape_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_var["bad"] = np.ones((10, 2)) + + def test_subset_var(self, adata): + adata.sec_var["x"] = np.arange(6).reshape(3, 2) + sub = adata[:, :2] + assert sub.sec_var["x"].shape == (2, 2) + + def test_subset_obs_unchanged(self, adata): + adata.sec_var["x"] = np.arange(6).reshape(3, 2) + sub = adata[:3] + assert sub.sec_var["x"].shape == (3, 2) + + +# --------------------------------------------------------------------------- +# Both axes: alignment=("obs", "var") +# --------------------------------------------------------------------------- + + +class TestBothAxes: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(5, 3) + adata.sec_both["x"] = arr + np.testing.assert_array_equal(adata.sec_both["x"], arr) + + def test_wrong_obs_shape_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_both["bad"] = np.ones((10, 3)) + + def test_wrong_var_shape_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_both["bad"] = np.ones((5, 10)) + + def test_subset_both(self, adata): + adata.sec_both["x"] = np.arange(15).reshape(5, 3) + sub = adata[:3, :2] + assert sub.sec_both["x"].shape == (3, 2) + + +# --------------------------------------------------------------------------- +# Pairwise obs: alignment=("obs", "obs") +# --------------------------------------------------------------------------- + + +class TestPairwiseObs: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(5, 5) + adata.sec_pair_obs["dist"] = arr + np.testing.assert_array_equal(adata.sec_pair_obs["dist"], arr) + + def test_non_square_raises(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.sec_pair_obs["bad"] = np.ones((5, 3)) + + def test_subset(self, adata): + adata.sec_pair_obs["dist"] = np.eye(5) + sub = adata[:3] + assert sub.sec_pair_obs["dist"].shape == (3, 3) + + +# --------------------------------------------------------------------------- +# Pairwise var: alignment=("var", "var") +# --------------------------------------------------------------------------- + + +class TestPairwiseVar: + def test_store_and_retrieve(self, adata): + arr = np.random.rand(3, 3) + adata.sec_pair_var["corr"] = arr + np.testing.assert_array_equal(adata.sec_pair_var["corr"], arr) + + def test_subset(self, adata): + adata.sec_pair_var["corr"] = np.eye(3) + sub = adata[:, :2] + assert sub.sec_pair_var["corr"].shape == (2, 2) + + +# --------------------------------------------------------------------------- +# Unaligned: alignment=() +# --------------------------------------------------------------------------- + + +class TestUnaligned: + def test_store_anything(self, adata): + adata.sec_unaligned["img"] = np.random.rand(100, 100, 3) + assert adata.sec_unaligned["img"].shape == (100, 100, 3) + + def test_no_shape_validation(self, adata): + # Any shape is fine for unaligned + adata.sec_unaligned["a"] = np.ones((1,)) + adata.sec_unaligned["b"] = np.ones((999, 888)) + assert len(adata.sec_unaligned) == 2 + + def test_subset_unchanged(self, adata): + adata.sec_unaligned["img"] = np.random.rand(100, 100, 3) + sub = adata[:3] + # Unaligned data is not subsetted + assert sub.sec_unaligned["img"].shape == (100, 100, 3) + + def test_non_array_values(self, adata): + adata.sec_unaligned["config"] = {"key": "value"} + assert adata.sec_unaligned["config"] == {"key": "value"} + + +# --------------------------------------------------------------------------- +# Custom type validation +# --------------------------------------------------------------------------- + + +class TestCustomValidation: + def test_type_check(self, adata): + adata.sec_typed["x"] = np.eye(5) + assert adata.sec_typed["x"].shape == (5, 5) + + def test_wrong_type_raises(self, adata): + with pytest.raises(TypeError, match="must be ndarray"): + adata.sec_typed["bad"] = [[1, 2], [3, 4]] + + def test_custom_validate(self, adata): + with pytest.raises(ValueError, match="must be 2D"): + adata.sec_typed["bad"] = np.ones(5) # 1D, not 2D + + +# --------------------------------------------------------------------------- +# Custom subset +# --------------------------------------------------------------------------- + + +class TestCustomSubset: + def test_custom_subset_fn(self, adata): + adata.sec_custom_subset["x"] = np.eye(5) + sub = adata[:3] + result = sub.sec_custom_subset["x"] + assert isinstance(result, dict) + assert "original" in result + assert "subset_idx" in result + + +# --------------------------------------------------------------------------- +# Init kwargs +# --------------------------------------------------------------------------- + + +class TestInitKwargs: + def test_init_with_section(self): + adata = ad.AnnData( + np.ones((3, 4)), + obs=pd.DataFrame(index=["c1", "c2", "c3"]), + sec_obs={"x": np.eye(3)}, + ) + assert "x" in adata.sec_obs + + def test_init_with_multiple_sections(self): + adata = ad.AnnData( + np.ones((3, 4)), + obs=pd.DataFrame(index=["c1", "c2", "c3"]), + var=pd.DataFrame(index=["v1", "v2", "v3", "v4"]), + sec_obs={"x": np.eye(3)}, + sec_var={"y": np.eye(4)}, + ) + assert "x" in adata.sec_obs + assert "y" in adata.sec_var + + def test_init_without_section(self): + adata = ad.AnnData(np.ones((3, 4))) + assert len(adata.sec_obs) == 0 + + +# --------------------------------------------------------------------------- +# IO roundtrip (h5ad) +# --------------------------------------------------------------------------- + + +class TestH5adRoundtrip: + def test_write_read_obs_aligned(self, adata, tmp_path): + adata.sec_obs["x"] = np.eye(5) + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert "x" in adata2.sec_obs + np.testing.assert_array_equal(adata2.sec_obs["x"], np.eye(5)) + + def test_write_read_both_axes(self, adata, tmp_path): + adata.sec_both["x"] = np.arange(15).reshape(5, 3).astype(float) + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_equal(adata2.sec_both["x"], np.arange(15).reshape(5, 3)) + + def test_empty_section_not_written(self, adata, tmp_path): + import h5py + + path = tmp_path / "test.h5ad" + adata.write(path) + with h5py.File(path, "r") as f: + assert "sec_obs" not in f + + def test_custom_serialize_deserialize(self, adata, tmp_path): + adata.sec_custom_io["config"] = {"lr": 0.001, "epochs": 100} + path = tmp_path / "test.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert adata2.sec_custom_io["config"] == {"lr": 0.001, "epochs": 100} + + def test_subset_then_write(self, adata, tmp_path): + adata.sec_obs["x"] = np.arange(15).reshape(5, 3).astype(float) + sub = adata[:3].copy() + path = tmp_path / "test.h5ad" + sub.write(path) + sub2 = ad.read_h5ad(path) + assert sub2.sec_obs["x"].shape == (3, 3) + + +# --------------------------------------------------------------------------- +# Repr +# --------------------------------------------------------------------------- + + +class TestRepr: + def test_repr_shows_section(self, adata): + adata.sec_obs["x"] = np.eye(5) + assert "sec_obs" in repr(adata) + + def test_repr_hides_empty(self, adata): + assert "sec_obs" not in repr(adata) + + +# --------------------------------------------------------------------------- +# Copy +# --------------------------------------------------------------------------- + + +class TestCopy: + def test_copy_preserves(self, adata): + adata.sec_obs["x"] = np.eye(5) + adata2 = adata.copy() + assert "x" in adata2.sec_obs + np.testing.assert_array_equal(adata2.sec_obs["x"], np.eye(5)) + + def test_copy_is_independent(self, adata): + adata.sec_obs["x"] = np.eye(5) + adata2 = adata.copy() + adata2.sec_obs["new"] = np.ones((5, 2)) + assert "new" not in adata.sec_obs + + def test_view_copy_on_write(self, adata): + adata.sec_obs["x"] = np.eye(5) + sub = adata[:3] + sub.sec_obs["new"] = np.ones((3, 2)) + assert not sub.is_view + assert "new" in sub.sec_obs + assert "new" not in adata.sec_obs + + +# --------------------------------------------------------------------------- +# TreeData-like end-to-end scenario +# --------------------------------------------------------------------------- + + +class TestTreeDataScenario: + def test_full_workflow(self, tmp_path): + n_obs, n_vars = 10, 5 + adata = ad.AnnData( + X=np.random.rand(n_obs, n_vars), + obs=pd.DataFrame( + {"cell_type": pd.Categorical(["A"] * 5 + ["B"] * 5)}, + index=[f"cell_{i}" for i in range(n_obs)], + ), + var=pd.DataFrame(index=[f"gene_{i}" for i in range(n_vars)]), + ) + + # Store tree embeddings + adata.sec_obs["lineage"] = np.random.rand(n_obs, 4) + adata.sec_var["phylogeny"] = np.random.rand(n_vars, 3) + + # Subset + mask = adata.obs["cell_type"] == "A" + sub = adata[mask] + assert sub.sec_obs["lineage"].shape == (5, 4) + assert sub.sec_var["phylogeny"].shape == (n_vars, 3) + + # IO roundtrip + path = tmp_path / "treedata.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert set(adata2.sec_obs.keys()) == {"lineage"} + assert set(adata2.sec_var.keys()) == {"phylogeny"} + + # Repr + assert "sec_obs" in repr(adata2) + assert "sec_var" in repr(adata2) + + +# --------------------------------------------------------------------------- +# SpatialData-like end-to-end scenario +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Cell-cell communication tensor: alignment=("obs", "obs", "var") +# --------------------------------------------------------------------------- + + +class TestCellCommunication: + """3D tensor for ligand-receptor communication scores. + + Tools like CellChat, LIANA, and CellPhoneDB compute communication + strengths between cell pairs mediated by specific genes. The natural + shape is (sender_cell, receiver_cell, gene). With alignment=("obs", + "obs", "var"), the tensor subsets correctly when filtering cells or genes. + """ + + def test_store_tensor(self, adata): + comm = np.random.rand(5, 5, 3) + adata.cellcomm["lr_scores"] = comm + assert adata.cellcomm["lr_scores"].shape == (5, 5, 3) + + def test_validates_obs_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((10, 10, 3)) + + def test_validates_var_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((5, 5, 10)) + + def test_validates_square_obs(self, adata): + """Sender and receiver must both be n_obs.""" + with pytest.raises(ValueError, match="shape"): + adata.cellcomm["bad"] = np.ones((5, 3, 3)) + + def test_subset_cells(self, adata): + """Filtering cells subsets both sender and receiver dims.""" + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:3] + assert sub.cellcomm["lr"].shape == (3, 3, 3) + + def test_subset_genes(self, adata): + """Filtering genes subsets the third dim.""" + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:, :2] + assert sub.cellcomm["lr"].shape == (5, 5, 2) + + def test_subset_both(self, adata): + adata.cellcomm["lr"] = np.random.rand(5, 5, 3) + sub = adata[:3, :2] + assert sub.cellcomm["lr"].shape == (3, 3, 2) + + def test_io_roundtrip(self, adata, tmp_path): + comm = np.random.rand(5, 5, 3) + adata.cellcomm["lr_scores"] = comm + path = tmp_path / "comm.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_almost_equal(adata2.cellcomm["lr_scores"], comm) + + def test_workflow(self, tmp_path): + """End-to-end: simulate CellChat-like analysis.""" + n_obs, n_vars = 20, 50 + adata = ad.AnnData( + X=np.random.rand(n_obs, n_vars), + obs=pd.DataFrame( + {"cell_type": pd.Categorical(["T"] * 10 + ["B"] * 10)}, + index=[f"cell_{i}" for i in range(n_obs)], + ), + var=pd.DataFrame( + {"is_ligand": [True] * 25 + [False] * 25}, + index=[f"gene_{i}" for i in range(n_vars)], + ), + ) + + # Compute communication scores (simulated) + adata.cellcomm["cellchat"] = np.random.rand(n_obs, n_obs, n_vars) + + # Filter to T cells only + t_cells = adata.obs["cell_type"] == "T" + sub = adata[t_cells] + assert sub.cellcomm["cellchat"].shape == (10, 10, n_vars) + + # Filter to ligand genes only + ligands = adata.var["is_ligand"] + sub2 = adata[:, ligands] + assert sub2.cellcomm["cellchat"].shape == (n_obs, n_obs, 25) + + # IO roundtrip + path = tmp_path / "cellchat.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert adata2.cellcomm["cellchat"].shape == (n_obs, n_obs, n_vars) + + +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + +# --------------------------------------------------------------------------- +# Cell-specific gene regulation: alignment=("obs", "var", "var") +# --------------------------------------------------------------------------- + + +class TestGeneRegulation: + """3D tensor for cell-specific gene regulatory networks. + + Each cell has its own gene-gene interaction matrix (e.g., inferred + from single-cell GRN methods like SCENIC, CellOracle, or Dictys). + Shape is (cell, source_gene, target_gene). Subsetting cells reduces + the first dim, subsetting genes reduces both gene dims. + """ + + def test_store_tensor(self, adata): + grn = np.random.rand(5, 3, 3) + adata.genereg["scenic"] = grn + assert adata.genereg["scenic"].shape == (5, 3, 3) + + def test_validates_obs_dim(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((10, 3, 3)) + + def test_validates_var_dims(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((5, 10, 3)) # source wrong + with pytest.raises(ValueError, match="shape"): + adata.genereg["bad"] = np.ones((5, 3, 10)) # target wrong + + def test_subset_cells(self, adata): + """Filtering cells subsets the first dim only.""" + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:3] + assert sub.genereg["grn"].shape == (3, 3, 3) + + def test_subset_genes(self, adata): + """Filtering genes subsets both gene dims (source and target).""" + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:, :2] + assert sub.genereg["grn"].shape == (5, 2, 2) + + def test_subset_both(self, adata): + adata.genereg["grn"] = np.random.rand(5, 3, 3) + sub = adata[:3, :2] + assert sub.genereg["grn"].shape == (3, 2, 2) + + def test_io_roundtrip(self, adata, tmp_path): + grn = np.random.rand(5, 3, 3) + adata.genereg["scenic"] = grn + path = tmp_path / "grn.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + np.testing.assert_array_almost_equal(adata2.genereg["scenic"], grn) + + +# --------------------------------------------------------------------------- +# Factored tensor: sections + accessor (scalable communication analysis) +# --------------------------------------------------------------------------- + + +class TestFactoredTensor: + """Store rank-R factors in sections, reconstruct tensor via accessor. + + For million-cell datasets, a dense (n_obs × n_obs × n_vars) tensor + is infeasible. Instead, store compact factors (n_obs × rank) and + (n_vars × rank), and reconstruct on demand. The factors subset + correctly, serialize to h5ad, and the accessor provides the tensor API. + """ + + @pytest.fixture(autouse=True) + def _ensure_accessor(self): + """Create and register the accessor (idempotent).""" + if hasattr(ad.AnnData, "comm"): + return + from anndata.extensions import ( + FormattedEntry, + FormattedOutput, + register_anndata_namespace, + ) + + @register_anndata_namespace("comm") + class CellCommAccessor: + section_after = "obsp" + section_tooltip = "Cell-cell communication (factored)" + + def __init__(self, adata: ad.AnnData): + self._adata = adata + + def tensor(self, key="default"): + """Reconstruct (obs × obs × var) tensor from factors.""" + U = self._adata.comm_obs[key] + V = self._adata.comm_var[key] + return np.einsum("ir,jr,kr->ijk", U, U, V) + + def query(self, sender, receiver, gene, key="default"): + """O(rank) point query without materializing tensor.""" + U = self._adata.comm_obs[key] + V = self._adata.comm_var[key] + i = self._adata.obs_names.get_loc(sender) + j = self._adata.obs_names.get_loc(receiver) + k = self._adata.var_names.get_loc(gene) + return float(U[i] @ (U[j] * V[k])) + + def _repr_section_(self, context): + keys = list(self._adata.comm_obs.keys()) + if not keys: + return None + return [ + FormattedEntry( + key=k, + output=FormattedOutput( + type_name=f"rank-{self._adata.comm_obs[k].shape[1]} factors", + preview=( + f"({self._adata.comm_obs[k].shape[0]} cells " + f"× {self._adata.comm_var[k].shape[0]} genes)" + ), + ), + ) + for k in keys + ] + + def test_store_factors(self, adata): + n_obs, n_vars, rank = 5, 3, 2 + adata.comm_obs["lr"] = np.random.rand(n_obs, rank) + adata.comm_var["lr"] = np.random.rand(n_vars, rank) + assert adata.comm_obs["lr"].shape == (n_obs, rank) + assert adata.comm_var["lr"].shape == (n_vars, rank) + + def test_reconstruct_tensor(self, adata): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + tensor = adata.comm.tensor("lr") + assert tensor.shape == (5, 5, 3) + + def test_point_query(self, adata): + rank = 3 + U = np.random.rand(5, rank) + V = np.random.rand(3, rank) + adata.comm_obs["lr"] = U + adata.comm_var["lr"] = V + score = adata.comm.query("c0", "c1", "v0", "lr") + expected = float(U[0] @ (U[1] * V[0])) + assert abs(score - expected) < 1e-10 + + def test_subset_preserves_reconstruction(self, adata): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + full_tensor = adata.comm.tensor("lr") + + sub = adata[:3, :2] + sub_tensor = sub.comm.tensor("lr") + assert sub_tensor.shape == (3, 3, 2) + np.testing.assert_array_almost_equal(sub_tensor, full_tensor[:3, :3, :2]) + + def test_io_roundtrip(self, adata, tmp_path): + rank = 3 + adata.comm_obs["lr"] = np.random.rand(5, rank) + adata.comm_var["lr"] = np.random.rand(3, rank) + tensor_before = adata.comm.tensor("lr") + + path = tmp_path / "factored.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + tensor_after = adata2.comm.tensor("lr") + np.testing.assert_array_almost_equal(tensor_before, tensor_after) + + def test_compression_ratio(self): + """Factors are orders of magnitude smaller than dense tensor.""" + n_obs, n_vars, rank = 1000, 500, 10 + factor_bytes = (n_obs * rank + n_vars * rank) * 8 # float64 + tensor_bytes = n_obs * n_obs * n_vars * 8 + ratio = tensor_bytes / factor_bytes + assert ratio > 100 # ~33,000× for these sizes + + +# --------------------------------------------------------------------------- +# xarray DataArray layers (custom type + serialize/deserialize) +# --------------------------------------------------------------------------- + + +class TestXarrayScenario: + """xarray DataArrays as layer values with custom serialization.""" + + def test_store_xarray(self, adata): + da = xr.DataArray(np.random.rand(5, 3), dims=["obs", "var"]) + adata.xr_layers["normalized"] = da + assert isinstance(adata.xr_layers["normalized"], xr.DataArray) + assert adata.xr_layers["normalized"].shape == (5, 3) + + def test_type_enforcement(self, adata): + with pytest.raises(TypeError, match="must be DataArray"): + adata.xr_layers["bad"] = np.ones((5, 3)) + + def test_alignment_validation(self, adata): + with pytest.raises(ValueError, match="shape"): + adata.xr_layers["bad"] = xr.DataArray(np.ones((10, 3))) + + def test_subset(self, adata): + da = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=["obs", "var"]) + adata.xr_layers["data"] = da + sub = adata[:3, :2] + result = sub.xr_layers["data"] + assert result.shape == (3, 2) + + def test_io_roundtrip(self, adata, tmp_path): + da = xr.DataArray(np.arange(15.0).reshape(5, 3), dims=["obs", "var"]) + adata.xr_layers["data"] = da + path = tmp_path / "xr.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + # Deserialized back to xarray + assert isinstance(adata2.xr_layers["data"], xr.DataArray) + np.testing.assert_array_equal( + adata2.xr_layers["data"].values, np.arange(15.0).reshape(5, 3) + ) + + def test_full_workflow(self, tmp_path): + """End-to-end: store, subset, copy, IO with xarray layers.""" + adata = ad.AnnData( + X=np.ones((10, 5)), + obs=pd.DataFrame(index=[f"c{i}" for i in range(10)]), + var=pd.DataFrame(index=[f"g{i}" for i in range(5)]), + xr_layers={ + "scaled": xr.DataArray(np.random.rand(10, 5), dims=["obs", "var"]), + }, + ) + + # Subset preserves type + sub = adata[:5] + assert isinstance(sub.xr_layers["scaled"], xr.DataArray) + assert sub.xr_layers["scaled"].shape == (5, 5) + + # Copy preserves type + copy = adata.copy() + assert isinstance(copy.xr_layers["scaled"], xr.DataArray) + + # IO roundtrip preserves type + path = tmp_path / "xr_workflow.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert isinstance(adata2.xr_layers["scaled"], xr.DataArray) + assert adata2.xr_layers["scaled"].shape == (10, 5) + + # Repr shows section + assert "xr_layers" in repr(adata) + + +# --------------------------------------------------------------------------- +# SpatialData-like end-to-end scenario +# --------------------------------------------------------------------------- + + +class TestSpatialDataScenario: + def test_unaligned_images(self, adata, tmp_path): + # Store images of arbitrary size + adata.sec_unaligned["hires"] = np.random.rand(200, 200, 3) + adata.sec_unaligned["lowres"] = np.random.rand(50, 50, 3) + + # Subsetting obs doesn't affect images + sub = adata[:3] + assert sub.sec_unaligned["hires"].shape == (200, 200, 3) + + # IO roundtrip + path = tmp_path / "spatial.h5ad" + adata.write(path) + adata2 = ad.read_h5ad(path) + assert set(adata2.sec_unaligned.keys()) == {"hires", "lowres"} + assert adata2.sec_unaligned["hires"].shape == (200, 200, 3)