diff --git a/docs/api.md b/docs/api.md index d89398505..c91ff4c22 100644 --- a/docs/api.md +++ b/docs/api.md @@ -147,6 +147,7 @@ See the {doc}`extensibility guide ` for how to implement a custo .. autosummary:: :toctree: api + experimental.im.calculate_image_features experimental.tl.calculate_tiling_qc experimental.tl.TilingQCParams experimental.tl.assign_stitch_groups diff --git a/src/squidpy/experimental/im/__init__.py b/src/squidpy/experimental/im/__init__.py index 7adea9d74..e6bef60a8 100644 --- a/src/squidpy/experimental/im/__init__.py +++ b/src/squidpy/experimental/im/__init__.py @@ -1,5 +1,6 @@ from __future__ import annotations +from ._calculate_image_features import calculate_image_features from ._detect_tissue import ( BackgroundDetectionParams, FelzenszwalbParams, @@ -29,6 +30,8 @@ "StainReference", "VahadaneParams", "WekaParams", + "apply_stain_normalization", + "calculate_image_features", "normalize_stains", "decompose_stains", "detect_tissue", diff --git a/src/squidpy/experimental/im/_calculate_image_features.py b/src/squidpy/experimental/im/_calculate_image_features.py new file mode 100644 index 000000000..9a37e8fee --- /dev/null +++ b/src/squidpy/experimental/im/_calculate_image_features.py @@ -0,0 +1,816 @@ +"""Experimental feature extraction module. + +Extracts per-cell features from segmentation masks using scikit-image +``regionprops`` and squidpy-specific metrics (summary statistics, GLCM +texture, colour histograms). Large images are automatically tiled so +that each tile is processed independently. 
@dataclass
class DropReport:
    """Tally of what a featurization run left out.

    At present the only thing counted is tiles that yielded no features.
    """

    # Number of tiles whose feature frame came back empty.
    empty_tiles: int = 0

    def summary(self) -> str:
        """Return a one-line, human-readable account of the skipped tiles."""
        skipped = self.empty_tiles
        return f"Skipped {skipped} empty tile(s)." if skipped != 0 else "No empty tiles."
def _parse_features(features: list[str] | str | None) -> _ParsedFeatures:
    """Parse user-facing feature flags into a structured configuration.

    Parameters
    ----------
    features
        A single flag or a list of flags. Accepted forms are the bare skimage
        groups (``"skimage:morphology"``, ``"skimage:intensity"``), a single
        skimage property (``"skimage:<group>:<prop>"``) and the squidpy flags
        (``"squidpy:summary"``, ``"squidpy:texture"``, ``"squidpy:color_hist"``).
        Repeating a flag is harmless.

    Returns
    -------
    The requested skimage properties per group (``None`` when a group was not
    requested) and one boolean per squidpy flag.

    Raises
    ------
    ValueError
        If ``features`` is ``None``, if a flag or property name is unknown, or
        if a bare skimage group is combined with single properties of the
        same group.
    TypeError
        If an entry of ``features`` is not a string.
    NotImplementedError
        If a ``cpmeasure:*`` flag is requested.
    """
    if features is None:
        raise ValueError(
            "`features` must be specified explicitly. "
            "Use e.g. `features=['skimage:morphology']` for skimage regionprops or "
            "`features=['squidpy:summary', 'squidpy:texture', 'squidpy:color_hist']` for squidpy-native features."
        )

    if isinstance(features, str):
        features = [features]

    known_props = {"morphology": _MASK_PROPS, "intensity": _INTENSITY_PROPS}
    single_props: dict[str, set[str]] = {group: set() for group in known_props}
    # Groups requested through the bare flag. This is tracked explicitly:
    # inferring "the bare flag was used" from a saturated property set misfires
    # when a flag is repeated or when every property is listed one by one.
    whole_groups: set[str] = set()
    squidpy_flags = {"squidpy:summary": False, "squidpy:texture": False, "squidpy:color_hist": False}

    for f in features:
        if not isinstance(f, str):
            raise TypeError(f"`features` entries must be strings; got {type(f).__name__} {f!r}.")
        if f in _CPMEASURE_FLAG_NAMES:
            raise NotImplementedError(f"cp_measure feature `{f}` is not yet implemented.")
        if f in squidpy_flags:
            squidpy_flags[f] = True
            continue

        # "skimage:<group>" or "skimage:<group>:<prop>"; everything after the
        # second colon belongs to the property name.
        namespace, _, remainder = f.partition(":")
        group, has_prop, prop = remainder.partition(":")
        if namespace != "skimage" or group not in known_props:
            # cp_measure flags get a specific NotImplementedError above; don't
            # advertise them in the "available" list since they always raise.
            supported = sorted(_ALL_FEATURES - _CPMEASURE_FLAG_NAMES)
            raise ValueError(
                f"Unknown feature: '{f}'. Available top-level features: {supported}, "
                f"or use 'skimage:morphology:property' / 'skimage:intensity:property' for individual properties."
            )

        mixing_msg = f"Mixing 'skimage:{group}' with 'skimage:{group}:' is ambiguous; pick one form."
        if not has_prop:
            if single_props[group]:
                raise ValueError(mixing_msg)
            whole_groups.add(group)
            continue

        if prop not in known_props[group]:
            raise ValueError(f"Unknown skimage {group} property: '{prop}'. Available: {sorted(known_props[group])}")
        if group in whole_groups:
            raise ValueError(mixing_msg)
        single_props[group].add(prop)

    resolved: dict[str, frozenset[str] | None] = {}
    for group, all_props in known_props.items():
        chosen = all_props if group in whole_groups else single_props[group]
        resolved[group] = frozenset(chosen) if chosen else None

    return _ParsedFeatures(
        skimage_morphology_props=resolved["morphology"],
        skimage_intensity_props=resolved["intensity"],
        squidpy_summary=squidpy_flags["squidpy:summary"],
        squidpy_texture=squidpy_flags["squidpy:texture"],
        squidpy_color_hist=squidpy_flags["squidpy:color_hist"],
    )
def _featurize_tile(
    tile_image: np.ndarray | None,
    tile_labels: np.ndarray,
    parsed: _ParsedFeatures,
    channel_names: list[str],
) -> pd.DataFrame:
    """Run every requested feature family on one tile and join the results.

    Parameters
    ----------
    tile_image
        ``(C, H, W)`` image tile, or ``None`` for a morphology-only run.
    tile_labels
        ``(H, W)`` label tile with only owned cells.
    parsed
        Parsed feature configuration.
    channel_names
        Channel names used when naming per-channel columns.

    Returns
    -------
    DataFrame indexed by the non-zero label IDs found in ``tile_labels``, with
    one column per feature. A tile without any non-zero label gives a
    completely empty frame.
    """
    owned = np.unique(tile_labels)
    owned = owned[owned != 0]
    if owned.size == 0:
        return pd.DataFrame()

    wants_skimage = not (parsed.skimage_morphology_props is None and parsed.skimage_intensity_props is None)
    wants_squidpy = any((parsed.squidpy_summary, parsed.squidpy_texture, parsed.squidpy_color_hist))

    frames: list[pd.DataFrame] = []
    if wants_skimage:
        # skimage regionprops (morphology and/or intensity)
        frames.append(
            _compute_skimage_features(
                tile_labels,
                tile_image,
                parsed.skimage_morphology_props,
                parsed.skimage_intensity_props,
                channel_names,
            )
        )
    if wants_squidpy:
        # squidpy per-cell summary / texture / histogram features
        frames.append(_compute_squidpy_per_cell(tile_labels, tile_image, parsed, channel_names))

    # Empty frames carry no columns, so leave them out of the join.
    frames = [frame for frame in frames if not frame.empty]
    if frames:
        return pd.concat(frames, axis=1).reindex(owned)
    return pd.DataFrame(index=owned)
def _compute_skimage_features(
    labels: np.ndarray,
    image: np.ndarray | None,
    morphology_props: frozenset[str] | None,
    intensity_props: frozenset[str] | None,
    channel_names: list[str],
) -> pd.DataFrame:
    """Compute skimage regionprops features for all cells in a tile.

    Uses :func:`skimage.measure.regionprops_table` rather than a per-region
    Python loop. Morphology props keep skimage's native flattened names (e.g.
    ``centroid-0``, ``inertia_tensor-0-0``). Intensity props are computed for
    every channel in one multichannel call and their columns are renamed from
    the channel index to the channel name.

    Parameters
    ----------
    labels
        Label array of the tile.
    image
        Channel-first image of the tile. Only read for ``intensity_props``;
        may be ``None`` for a morphology-only run.
    morphology_props
        Shape properties to compute, or ``None`` to skip them.
    intensity_props
        Intensity properties to compute, or ``None`` to skip them.
    channel_names
        Names used to label the per-channel intensity columns.

    Returns
    -------
    Label-indexed DataFrame with one column per computed property; an empty
    DataFrame when neither property set is requested.

    Raises
    ------
    ValueError
        If ``intensity_props`` is given but ``image`` is ``None``.
    """
    # Fail early and readably: without this guard the missing image only
    # surfaces as an axis error from ``np.moveaxis`` further down.
    if intensity_props is not None and image is None:
        raise ValueError("Intensity properties were requested but no image was provided.")

    parts: list[pd.DataFrame] = []

    if morphology_props is not None:
        table = measure.regionprops_table(labels, properties=["label", *sorted(morphology_props)])
        parts.append(_regionprops_table_to_df(table))

    if intensity_props is not None:
        # One multichannel pass: moveaxis -> (y, x, c) so skimage treats the last
        # axis as channels and computes every channel at once.
        table = measure.regionprops_table(
            labels,
            intensity_image=np.moveaxis(image, 0, -1),
            properties=["label", *sorted(intensity_props)],
        )
        parts.append(_regionprops_table_to_df(table, lambda col: _rename_intensity_col(col, channel_names)))

    if not parts:
        return pd.DataFrame()
    return pd.concat(parts, axis=1)
def _histogram_features(masked_vals: np.ndarray, ch_name: str, bins: int = 16) -> dict[str, float]:
    """Normalised intensity histogram of one cell in one channel.

    Parameters
    ----------
    masked_vals
        1-D array with the cell's pixel values for the channel.
    ch_name
        Channel name, used as the suffix of every returned key.
    bins
        Number of equal-width bins between the smallest and largest finite
        value.

    Returns
    -------
    Mapping ``color_hist_bin<i>_<ch_name>`` to the fraction of finite pixels in
    bin ``i``. Non-finite pixels (NaN / inf) are ignored. If no finite pixel
    is left, every bin is ``0.0``, so the set of keys does not depend on the
    data.
    """
    values = np.asarray(masked_vals)
    # NaN / inf would make the histogram range non-finite, which
    # ``np.histogram`` rejects with a ValueError; drop them up front.
    finite_vals = values[np.isfinite(values)]

    if finite_vals.size == 0:
        hist = np.zeros(bins, dtype=np.float32)
    else:
        lo, hi = float(finite_vals.min()), float(finite_vals.max())
        # A constant cell has a zero-width range; widen it so all pixels land
        # in the first bin.
        upper = hi if hi > lo else lo + 1
        counts, _ = np.histogram(finite_vals, bins=bins, range=(lo, upper))
        hist = counts.astype(np.float32)
        total = hist.sum()
        if total > 0:
            hist = hist / total

    return {f"color_hist_bin{b}_{ch_name}": float(v) for b, v in enumerate(hist)}
def _validate_inputs(
    sdata: SpatialData,
    image_key: str | None,
    labels_key: str | None,
    shapes_key: str | None,
    scale: str | None,
) -> None:
    """Check that the requested keys exist and are usable, without loading data.

    Exactly one of ``labels_key`` / ``shapes_key`` must be given and present in
    ``sdata``. ``image_key`` is optional but must exist when given. A
    multi-scale labels or image element additionally needs ``scale``.

    Whether an image is required at all depends on the requested features and
    is decided by :func:`calculate_image_features`, not here.

    Raises
    ------
    ValueError
        On the first rule that is violated.
    """
    use_labels = labels_key is not None
    use_shapes = shapes_key is not None

    if not (use_labels or use_shapes):
        raise ValueError("Provide either `labels_key` or `shapes_key`.")
    if use_labels and use_shapes:
        raise ValueError("Use either `labels_key` or `shapes_key`, not both.")

    # From here on exactly one segmentation source is set.
    if use_labels:
        if labels_key not in sdata.labels:
            raise ValueError(f"Labels key '{labels_key}' not found, valid keys: {list(sdata.labels.keys())}")
        if isinstance(sdata.labels[labels_key], xr.DataTree) and scale is None:
            raise ValueError("When using multi-scale labels, please specify the scale.")
    elif shapes_key not in sdata.shapes:
        raise ValueError(f"Shapes key '{shapes_key}' not found, valid keys: {list(sdata.shapes.keys())}")

    if image_key is None:
        return
    if image_key not in sdata.images:
        raise ValueError(f"Image key '{image_key}' not found, valid keys: {list(sdata.images.keys())}")
    if isinstance(sdata.images[image_key], xr.DataTree) and scale is None:
        raise ValueError("When using multi-scale images, please specify the scale.")
list[str]]: + """Return lazy image and labels DataArrays, plus channel names. + + ``image_da`` is ``None`` (and ``channel_names`` empty) for a morphology-only + run with no ``image_key``. Does NOT call ``.compute()`` - arrays stay lazy + for on-demand tile reads. For the shapes->labels path, labels are + materialized but wrapped in a DataArray for a uniform interface. + """ + _validate_inputs(sdata, image_key, labels_key, shapes_key, scale) + + # Only strict, axis-aligned image/labels are supported. The Literal narrows + # align_mode statically; this guard catches callers passing it dynamically. + if align_mode != "strict": + raise ValueError(f"`align_mode` must be 'strict'; got {align_mode!r}.") + + image_da = None + if image_key is not None: + image_da = _resolve_da(sdata.images[image_key], scale) + if "c" not in image_da.dims: + image_da = image_da.expand_dims("c") + + # Labels DataArray (lazy for labels_key, materialized for shapes_key). + if labels_key is not None: + labels_da = _resolve_da(sdata.labels[labels_key], scale) + else: + # shapes_key requires an image to size the rasterization grid (enforced + # by calculate_image_features), so image_da is not None here. + logg.info("Converting shapes to labels.") + img_shape = {d: image_da.sizes[d] for d in ("y", "x")} + try: + labels_result = rasterize( + sdata.shapes[shapes_key], + ["x", "y"], + min_coordinate=[0, 0], + max_coordinate=[img_shape["x"], img_shape["y"]], + target_coordinate_system="global", + target_unit_to_pixels=1.0, + return_regions_as_labels=True, + ) + except ValueError as e: + raise ValueError( + "Failed to rasterize shapes; geometries may be empty or unsupported. " + "Filter out empty/non-polygon geometries or choose a different shapes_key." + ) from e + labels_da = ( + labels_result + if isinstance(labels_result, xr.DataArray) + else xr.DataArray(np.asarray(labels_result), dims=["y", "x"]) + ) + + # Image and labels must share a pixel grid (only checkable with an image). 
+ if image_da is not None and labels_key is not None: + iy, ix = image_da.sizes.get("y"), image_da.sizes.get("x") + ly, lx = labels_da.sizes.get("y"), labels_da.sizes.get("x") + if (iy, ix) != (ly, lx): + raise ValueError( + f"Image (y={iy}, x={ix}) and labels (y={ly}, x={lx}) have different " + f"pixel grids. Pre-align with `spatialdata.rasterize`." + ) + + if image_da is None: + return image_da, labels_da, [] + + # Resolve channel names through spatialdata's canonical accessor so we + # honor c_coords set at parse time. Always cast to str. + all_ch = [str(v) for v in get_channel_names(sdata.images[image_key])] + if len(all_ch) != image_da.sizes["c"]: + # Multiscale element where get_channel_names may report from a + # different scale than image_da. Fall back to positional naming. + all_ch = [str(i) for i in range(image_da.sizes["c"])] + + ch_names: list[str] + if channels is not None: + selected_idx: list[int] = [] + ch_names = [] + for ch in channels: + if not isinstance(ch, str): + raise TypeError( + f"channels must contain strings (channel names); got {type(ch).__name__} {ch!r}. " + f"Available channel names: {all_ch}." + ) + if ch not in all_ch: + raise ValueError(f"Channel '{ch}' not found. 
def _compute_centroids(
    sdata: SpatialData,
    labels_key: str | None,
    labels_da: xr.DataArray,
    scale: str | None,
) -> dict:
    """Choose a cell-info strategy from the kind and size of the labels.

    Three cases are distinguished:

    - ``labels_key`` points at a multi-scale (``DataTree``) element: delegate
      to ``compute_cell_info_multiscale`` with ``scale`` (``"scale0"`` when
      ``scale`` is falsy) as the target scale.
    - at most ``4096 * 4096`` pixels: materialise the labels and call
      ``compute_cell_info`` on the (squeezed, if more than 2-D) array.
    - anything larger: ``compute_cell_info_tiled`` on the lazy array.

    Returns whatever the selected helper returns.
    """
    is_multiscale = labels_key is not None and isinstance(sdata.labels[labels_key], xr.DataTree)
    if is_multiscale:
        logg.info("Computing centroids from coarse scale.")
        target = scale or "scale0"
        return compute_cell_info_multiscale(sdata.labels[labels_key], target_scale=target)

    in_memory_limit = 4096 * 4096
    height = labels_da.sizes.get("y", 1)
    width = labels_da.sizes.get("x", 1)
    if height * width > in_memory_limit:
        # Too large to materialise in one go.
        logg.info("Computing centroids in tiled mode (large single-scale labels).")
        return compute_cell_info_tiled(labels_da)

    dense = labels_da.values
    return compute_cell_info(dense.squeeze() if dense.ndim > 2 else dense)
+ + Uses scikit-image ``regionprops`` for morphological/intensity features + and squidpy-specific per-cell metrics (summary statistics, GLCM texture, + colour histograms). Large images are automatically tiled into + ``tile_size x tile_size`` chunks with overlap so that every cell is + fully contained in exactly one tile. + + Parameters + ---------- + sdata + SpatialData object. + image_key + Key in ``sdata.images``. Optional: required only for intensity / squidpy + features (and for ``shapes_key``). Morphology-only runs need no image. + labels_key + Key in ``sdata.labels`` with segmentation masks. + shapes_key + Key in ``sdata.shapes`` (rasterized to labels internally). + scale + Scale level for multi-scale data. + channels + Subset of channel names to use, matching those returned by + :func:`spatialdata.models.get_channel_names`. ``None`` uses all + channels. Integer indices are not accepted -- always pass names. + features + Which features to compute (required; ``None`` is rejected). A list of + flag strings drawn from two groups: + + - **skimage regionprops** -- ``"skimage:morphology"`` (all shape props, + from the mask alone) or ``"skimage:morphology:"`` for one + (e.g. ``area``); ``"skimage:intensity"`` (all per-channel intensity + props, needs an image) or ``"skimage:intensity:"`` for one + (e.g. ``intensity_mean``). Morphology columns use skimage's native + names (``area``, ``centroid-0``); intensity columns are suffixed with + the channel name. + - **squidpy per-cell** -- ``"squidpy:summary"`` (per-channel mean / std / + min / max), ``"squidpy:texture"`` (per-channel GLCM contrast, + dissimilarity, homogeneity, energy, ASM, correlation), and + ``"squidpy:color_hist"`` (per-channel intensity histogram). + tile_size + Side length of the tiling grid (pixels). + overlap_margin + Overlap around each tile to capture boundary cells. + ``"auto"`` computes the minimum from the largest cell's bounding box. 
+ align_mode + Only ``"strict"`` is supported: require image and labels to + share the same pixel grid (same y/x sizes). Raise otherwise. + adata_key_added + Key under which to store the result in ``sdata.tables``. + invalid_as_zero + Replace ``inf`` and ``NaN`` values with zero. + n_jobs + Number of parallel jobs for tile processing. + inplace + If ``True``, store result in ``sdata.tables``. Otherwise return it. + + Returns + ------- + :class:`~anndata.AnnData` when ``inplace=False``, otherwise ``None``. + + Examples + -------- + >>> import squidpy as sq + >>> sq.experimental.im.calculate_image_features( + ... sdata, + ... image_key="image", + ... labels_key="cells", + ... features=["skimage:morphology", "skimage:intensity", "squidpy:summary"], + ... ) # doctest: +SKIP + + Morphology-only needs no image: + + >>> sq.experimental.im.calculate_image_features( + ... sdata, labels_key="cells", features=["skimage:morphology:area"] + ... ) # doctest: +SKIP + + The per-cell table is stored in ``sdata.tables["morphology"]``. + """ + # --- Parse & validate --- + parsed = _parse_features(features) + if not _has_any_features(parsed): + raise ValueError("No valid features requested.") + + # An image is needed only for intensity / squidpy features; morphology runs + # from the labels alone. Reject the cases that genuinely cannot proceed. 
+ if image_key is None: + needs_image = _image_requiring_features(parsed) + if needs_image: + raise ValueError(f"Features {needs_image} require pixel data; pass `image_key`.") + if shapes_key is not None: + raise ValueError("`shapes_key` requires `image_key` (rasterization needs the image grid).") + if channels is not None: + raise ValueError("`channels` selection requires `image_key`.") + + drop_report = DropReport() + + image_da, labels_da, channel_names = _prepare_lazy( + sdata, image_key, labels_key, shapes_key, scale, channels, align_mode + ) + + # --- Warmup: compute centroids without materializing full arrays --- + cell_info = _compute_centroids(sdata, labels_key, labels_da, scale) + if not cell_info: + logg.info(drop_report.summary()) + raise ValueError("No cells found in labels (all zeros).") + + H = int(labels_da.sizes.get("y", labels_da.shape[-2])) + W = int(labels_da.sizes.get("x", labels_da.shape[-1])) + + # --- Tile --- + specs = build_tile_specs((H, W), cell_info, tile_size=tile_size, overlap_margin=overlap_margin) + total_tiles = len(specs) + logg.info(f"Processing {total_tiles} tiles ({tile_size}x{tile_size}, margin={overlap_margin}).") + + # --- Process tiles (each worker materializes only its own ~2k x 2k crop) --- + def _process_one(spec): + if image_da is None: + tile_lbl = extract_labels_tile_lazy(labels_da, spec) + return _featurize_tile(None, tile_lbl, parsed, channel_names) + tile_img, tile_lbl = extract_tile_lazy(image_da, labels_da, spec) + return _featurize_tile(tile_img, tile_lbl, parsed, channel_names) + + log_every = max(1, total_tiles // 10) + start_t = time.monotonic() + tile_dfs: list[pd.DataFrame] = [] + results_iter = Parallel(n_jobs=n_jobs, prefer="threads", return_as="generator_unordered")( + delayed(_process_one)(spec) for spec in specs + ) + for done, df in enumerate( + tqdm(results_iter, total=total_tiles, desc="Featurizing tiles", unit="tile"), + start=1, + ): + if df.empty: + drop_report.empty_tiles += 1 + else: + 
tile_dfs.append(df) + if done == 1 or done == total_tiles or done % log_every == 0: + elapsed = time.monotonic() - start_t + logg.info(f"Tile {done}/{total_tiles} done (elapsed {elapsed:.1f}s).") + + if not tile_dfs: + logg.info(drop_report.summary()) + raise ValueError("No features computed for any tile.") + + # Sort by cell label for deterministic output. inf/NaN handling happens + # in one numpy pass below to avoid two extra full-table allocations. + combined = pd.concat(tile_dfs, axis=0).sort_index() + + # --- Build AnnData --- + # Exactly one of labels_key / shapes_key is set (enforced in _validate_inputs). + region_key_value = labels_key or shapes_key + + arr = combined.to_numpy(dtype=np.float32, copy=True) + if invalid_as_zero: + np.nan_to_num(arr, copy=False, nan=0.0, posinf=0.0, neginf=0.0) + adata = ad.AnnData(X=arr) + adata.obs_names = [f"cell_{i}" for i in combined.index] + adata.var_names = list(combined.columns) + + adata.uns["spatialdata_attrs"] = { + "region": region_key_value, + "region_key": "region", + "instance_key": "label_id", + } + adata.obs["region"] = pd.Categorical.from_codes(np.zeros(len(adata), dtype=np.int8), categories=[region_key_value]) + + if shapes_key is not None and len(sdata.shapes[shapes_key]) == len(adata): + adata.obs["label_id"] = sdata.shapes[shapes_key].index.values + else: + adata.obs["label_id"] = combined.index.values + + logg.info(drop_report.summary()) + + if inplace: + sdata.tables[adata_key_added] = TableModel.parse(adata) + return None + return adata diff --git a/tests/experimental/test_calculate_image_features.py b/tests/experimental/test_calculate_image_features.py new file mode 100644 index 000000000..c8953f7a9 --- /dev/null +++ b/tests/experimental/test_calculate_image_features.py @@ -0,0 +1,873 @@ +"""Tests for calculate_image_features. + +Uses a small synthetic SpatialData (200x200 image, ~20 cells) so tests +run in seconds without downloading real data. 
@pytest.fixture()
def sdata_synthetic():
    """Synthetic SpatialData with a small 3-channel image and 16 square cells.

    The image is 200x200 with channels named ``R``, ``G``, ``B`` and random
    ``uint8`` content from a fixed seed. The labels hold a 4x4 grid of
    non-overlapping 30x30 squares with IDs 1..16; everything else is 0.
    Elements are stored under ``"test_img"`` and ``"test_labels"``.
    """
    rng = np.random.default_rng(42)
    H, W, C = 200, 200, 3

    # Upper bound is exclusive, so pixel values lie in 0..254.
    image_data = rng.integers(0, 255, (C, H, W), dtype=np.uint8)
    image_xr = xr.DataArray(
        image_data,
        dims=["c", "y", "x"],
        coords={"c": ["R", "G", "B"]},
    )

    # Place rectangular cells in a grid (non-overlapping, 30x30 each).
    # Offsets are 10, 50, 90, 130 on both axes, i.e. 4 x 4 = 16 cells.
    labels_data = np.zeros((H, W), dtype=np.int32)
    cell_id = 0
    for y in range(10, H - 30, 40):
        for x in range(10, W - 30, 40):
            cell_id += 1
            labels_data[y : y + 30, x : x + 30] = cell_id

    labels_xr = xr.DataArray(labels_data, dims=["y", "x"])

    return SpatialData(
        images={"test_img": Image2DModel.parse(image_xr)},
        labels={"test_labels": Labels2DModel.parse(labels_xr)},
    )
def test_skimage_intensity_single_property(self, sdata_synthetic):
    """Fine-grained: only intensity_mean per channel.

    Requesting ``skimage:intensity:intensity_mean`` must yield at least one
    column, every column must be an ``intensity_mean_*`` column, and no
    ``intensity_max*`` column may appear.
    """
    result = sq.experimental.im.calculate_image_features(
        sdata_synthetic,
        image_key="test_img",
        labels_key="test_labels",
        features=["skimage:intensity:intensity_mean"],
        inplace=False,
    )
    columns = list(result.var_names)
    # ``all(...)`` is vacuously true for an empty iterable, so without this
    # guard the test would pass even if no feature column was produced.
    assert columns, "expected at least one intensity_mean column"
    assert all(col.startswith("intensity_mean_") for col in columns)
    assert not any(col.startswith("intensity_max") for col in columns)
def test_align_mode_rasterize_rejected(self, sdata_synthetic):
    """Reject any ``align_mode`` other than ``'strict'`` at runtime.

    The type annotation already limits ``align_mode`` for static checkers;
    this test covers callers that bypass it and pass another value.
    """
    featurize = sq.experimental.im.calculate_image_features
    expect_rejection = pytest.raises(ValueError, match="`align_mode` must be 'strict'")

    with expect_rejection:
        featurize(
            sdata_synthetic,
            image_key="test_img",
            labels_key="test_labels",
            features=["skimage:morphology:area"],
            align_mode="rasterize",  # type: ignore[arg-type]
            inplace=False,
        )
def test_invalid_feature(self, sdata_synthetic):
    """An unrecognised feature flag raises and lists only usable features."""
    featurize = sq.experimental.im.calculate_image_features

    with pytest.raises(ValueError, match="Unknown feature") as excinfo:
        featurize(
            sdata_synthetic,
            image_key="test_img",
            labels_key="test_labels",
            features=["nonexistent:measurement"],
        )

    message = str(excinfo.value)
    # cpmeasure:* flags are recognised but always raise NotImplementedError,
    # so the unknown-feature error must not advertise them as available.
    assert "cpmeasure:" not in message
    assert "squidpy:summary" in message
def test_tiled_vs_single_tile_equivalence(self, sdata_synthetic):
    """Tiling must not change tile-invariant features.

    The same request is run once with a tile size larger than the image and
    once with ``tile_size=100``; the per-cell values are then compared after
    aligning both results on ``label_id``. Only ``area`` and
    ``squidpy:summary`` are requested because they depend solely on a cell's
    own pixels, unlike position-dependent features such as centroids.
    """
    featurize = sq.experimental.im.calculate_image_features
    common_kwargs = {
        "image_key": "test_img",
        "labels_key": "test_labels",
        "features": ["skimage:morphology:area", "squidpy:summary"],
        "inplace": False,
        "invalid_as_zero": True,
    }

    # tile_size >= image extent -> the whole image is one tile.
    untiled = featurize(sdata_synthetic, tile_size=1000, **common_kwargs)
    # tile_size=100 on the 200x200 fixture -> several tiles.
    tiled = featurize(sdata_synthetic, tile_size=100, **common_kwargs)

    assert untiled.n_obs == tiled.n_obs
    assert set(untiled.var_names) == set(tiled.var_names)

    columns = list(untiled.var_names)

    def _by_label(adata, matrix):
        """Return *matrix* as a frame indexed by label_id, rows sorted by label."""
        frame = pd.DataFrame(matrix, index=adata.obs["label_id"].values, columns=columns)
        return frame.sort_index()

    frame_untiled = _by_label(untiled, untiled.X)
    # Re-order the tiled columns to the untiled order before comparing.
    frame_tiled = _by_label(tiled, tiled[:, columns].X)

    np.testing.assert_array_equal(frame_untiled.index, frame_tiled.index)
    np.testing.assert_allclose(frame_untiled.values, frame_tiled.values, rtol=1e-5, atol=1e-5)
def _toy_sdata(
    image_shape: tuple[int, int] = (200, 200),
    n_channels: int = 3,
    channel_names: list[str] | None = None,
    labels_shape: tuple[int, int] | None = None,
    labels_translation: tuple[float, float] | None = None,
    labels_scale: tuple[float, float] | None = None,
    label_ids: list[int] | None = None,
) -> SpatialData:
    """Build a synthetic SpatialData with controllable label/image transforms.

    Parameters
    ----------
    image_shape
        ``(height, width)`` of the random ``uint8`` image.
    n_channels
        Number of image channels.
    channel_names
        Optional channel names forwarded as ``c_coords``; when ``None`` the
        model's default channel coordinates are used.
    labels_shape
        ``(height, width)`` of the label image; defaults to ``image_shape``.
    labels_translation
        Optional ``(ty, tx)`` translation of the labels in the ``"global"``
        coordinate system.
    labels_scale
        Optional ``(sy, sx)`` scale of the labels in the ``"global"``
        coordinate system.
    label_ids
        Label values to draw, laid out three per row; defaults to ``1..5``.
        IDs whose rectangle would not fit inside the label image are skipped.

    Returns
    -------
    A SpatialData with image ``"img"`` and labels ``"lbl"``.

    Notes
    -----
    When both ``labels_scale`` and ``labels_translation`` are given they are
    composed as scale-then-translate. Previously the second
    ``set_transformation`` call replaced the first, so the translation was
    silently dropped.
    """
    from spatialdata.transformations import Scale, Sequence, Translation, set_transformation

    rng = np.random.default_rng(0)
    H, W = image_shape
    image_data = rng.integers(0, 255, (n_channels, H, W), dtype=np.uint8)
    image_xr = xr.DataArray(image_data, dims=["c", "y", "x"])

    LH, LW = labels_shape if labels_shape is not None else image_shape
    labels_data = np.zeros((LH, LW), dtype=np.int32)
    ids = label_ids if label_ids is not None else list(range(1, 6))
    cell_h, cell_w = max(LH // 8, 4), max(LW // 8, 4)
    for i, lid in enumerate(ids):
        row, col = divmod(i, 3)
        y0 = 10 + row * (cell_h + 6)
        x0 = 10 + col * (cell_w + 6)
        # Skip cells that would spill over the label image border.
        if y0 + cell_h > LH or x0 + cell_w > LW:
            continue
        labels_data[y0 : y0 + cell_h, x0 : x0 + cell_w] = lid
    labels_xr = xr.DataArray(labels_data, dims=["y", "x"])

    img_el = (
        Image2DModel.parse(image_xr, c_coords=channel_names)
        if channel_names is not None
        else Image2DModel.parse(image_xr)
    )
    lbl_el = Labels2DModel.parse(labels_xr)

    # Collect the requested transforms first so that asking for both does not
    # let one overwrite the other.
    transforms = []
    if labels_scale is not None:
        sy, sx = labels_scale
        transforms.append(Scale([sx, sy], axes=("x", "y")))
    if labels_translation is not None:
        ty, tx = labels_translation
        transforms.append(Translation([tx, ty], axes=("x", "y")))

    if len(transforms) == 1:
        # Single transform: identical to the original behaviour.
        set_transformation(lbl_el, transforms[0], "global")
    elif transforms:
        set_transformation(lbl_el, Sequence(transforms), "global")

    return SpatialData(images={"img": img_el}, labels={"lbl": lbl_el})
def test_concern2_progress_log_emitted(self, capsys):
    """A multi-tile run writes at least one ``Tile i/n`` progress line to stdout."""
    import re

    sdata = _toy_sdata()
    sq.experimental.im.calculate_image_features(
        sdata,
        image_key="img",
        labels_key="lbl",
        features=["skimage:morphology:area"],
        tile_size=80,  # forces >1 tile on 200x200
        inplace=False,
    )

    stdout = capsys.readouterr().out
    # The logger output may carry ANSI colour escapes between tokens (e.g.
    # around the digits of "Tile 1/9"); strip them before matching.
    plain = re.sub(r"\x1b\[[0-9;]*m", "", stdout)
    assert re.search(r"Tile \d+/\d+", plain), f"no progress log in:\n{plain}"
def _sdata_with_shapes() -> tuple[SpatialData, dict[int, float]]:
    """Return a shapes-based SpatialData and the expected area per shape index.

    The object holds a random 3-channel 200x200 image (``"test_img"``) and four
    axis-aligned square polygons (``"cells"``). Every square has a different
    edge length and a non-default index value, so a test can verify that each
    ``label_id`` is paired with the right cell rather than merely checking
    that some index is present.

    The second return value maps each shapes index to ``edge * edge``.
    """
    rng = np.random.default_rng(7)
    pixels = rng.integers(0, 255, (3, 200, 200), dtype=np.uint8)
    image_xr = xr.DataArray(
        pixels,
        dims=["c", "y", "x"],
        coords={"c": ["R", "G", "B"]},
    )

    # (shapes index, (centre x, centre y), edge length)
    specs = [
        (10, (50, 50), 20),
        (20, (150, 50), 30),
        (30, (50, 150), 40),
        (40, (150, 150), 50),
    ]

    def _square(center, edge):
        """Axis-aligned square polygon of side *edge* centred on *center*."""
        cx, cy = center
        half = edge / 2.0
        return Polygon(
            [
                (cx - half, cy - half),
                (cx + half, cy - half),
                (cx + half, cy + half),
                (cx - half, cy + half),
            ]
        )

    index = [idx for idx, _, _ in specs]
    polys = [_square(center, edge) for _, center, edge in specs]
    expected_area = {idx: float(edge * edge) for idx, _, edge in specs}

    shapes = ShapesModel.parse(gpd.GeoDataFrame(geometry=polys, index=index))
    sdata = SpatialData(images={"test_img": Image2DModel.parse(image_xr)}, shapes={"cells": shapes})
    return sdata, expected_area
def test_all_zero_labels_raises(sdata_synthetic):
    """A label image made only of background (0) must raise a clear error."""
    background_only = np.zeros((200, 200), dtype=np.int32)
    empty_labels = Labels2DModel.parse(xr.DataArray(background_only, dims=["y", "x"]))
    # Register the empty element next to the fixture's regular labels.
    sdata_synthetic.labels["empty"] = empty_labels

    featurize = sq.experimental.im.calculate_image_features
    with pytest.raises(ValueError, match="No cells found in labels"):
        featurize(
            sdata_synthetic,
            image_key="test_img",
            labels_key="empty",
            features=["skimage:morphology:area"],
            inplace=False,
        )
def _multiscale_sdata(multiscale_image: bool = True, multiscale_labels: bool = True) -> SpatialData:
    """Build a SpatialData whose image and/or labels are parsed as multiscale.

    The image (``"img"``) is a random 3-channel 256x256 ``uint8`` array; the
    labels (``"lbl"``) hold a 4x4 grid of 30x30 squares numbered 1..16 in
    row-major order. For each flag that is true, the corresponding element is
    parsed with ``scale_factors=[2]``; otherwise ``scale_factors=None`` is
    passed.
    """
    rng = np.random.default_rng(3)
    pixels = rng.integers(0, 255, (3, 256, 256), dtype=np.uint8)
    image_xr = xr.DataArray(
        pixels,
        dims=["c", "y", "x"],
        coords={"c": ["R", "G", "B"]},
    )

    # Top-left corners of the squares, row by row.
    corners = [(top, left) for top in range(20, 220, 60) for left in range(20, 220, 60)]
    mask = np.zeros((256, 256), dtype=np.int32)
    for cell_id, (top, left) in enumerate(corners, start=1):
        mask[top : top + 30, left : left + 30] = cell_id
    labels_xr = xr.DataArray(mask, dims=["y", "x"])

    image_factors = [2] if multiscale_image else None
    label_factors = [2] if multiscale_labels else None
    return SpatialData(
        images={"img": Image2DModel.parse(image_xr, scale_factors=image_factors)},
        labels={"lbl": Labels2DModel.parse(labels_xr, scale_factors=label_factors)},
    )
def test_multiscale_featurized_with_scale(self):
    """With an explicit ``scale``, multiscale inputs are featurized at that level.

    The fixture draws 16 squares of 30x30 = 900 px at full resolution, so
    checking the exact row count, the label IDs and an area of 900 for every
    cell shows that ``scale0`` was read and that no cell was dropped.
    """
    sdata = _multiscale_sdata(multiscale_image=True, multiscale_labels=True)
    request = {
        "image_key": "img",
        "labels_key": "lbl",
        "scale": "scale0",
        "features": ["skimage:morphology:area"],
        "inplace": False,
    }
    adata = sq.experimental.im.calculate_image_features(sdata, **request)

    assert adata.n_obs == 16

    observed_ids = set(adata.obs["label_id"].astype(int))
    assert observed_ids == set(range(1, 17))

    areas = adata[:, "area"].X.ravel()
    np.testing.assert_array_equal(areas, np.full(16, 900.0))
def test_morphology_only_without_image_parallel(self, sdata_synthetic):
    """Morphology without ``image_key`` also runs with several tiles and ``n_jobs=2``."""
    request = {
        "labels_key": "test_labels",
        "features": ["skimage:morphology:area"],
        "tile_size": 100,  # >1 tile on the 200x200 grid
        "n_jobs": 2,
        "inplace": False,
    }
    result = sq.experimental.im.calculate_image_features(sdata_synthetic, **request)
    assert result.n_obs > 0
def test_shapes_without_image_raises(self):
    """``shapes_key`` without ``image_key`` is rejected even if an image exists."""
    rng = np.random.default_rng(7)

    square = Polygon([(40, 40), (70, 40), (70, 70), (40, 70)])
    shapes = ShapesModel.parse(gpd.GeoDataFrame(geometry=[square]))

    # The object does contain an image; the call below simply does not name it.
    pixels = rng.integers(0, 255, (1, 200, 200), dtype=np.uint8)
    image_xr = xr.DataArray(pixels, dims=["c", "y", "x"], coords={"c": ["x"]})
    sdata = SpatialData(images={"img": Image2DModel.parse(image_xr)}, shapes={"cells": shapes})

    featurize = sq.experimental.im.calculate_image_features
    with pytest.raises(ValueError, match="`shapes_key` requires `image_key`"):
        featurize(
            sdata,
            shapes_key="cells",
            features=["skimage:morphology:area"],
            inplace=False,
        )