Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/squidpy/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@

from __future__ import annotations

from . import im, pl
from . import im, pl, tl

__all__ = ["im", "pl"]
__all__ = ["im", "pl", "tl"]
263 changes: 228 additions & 35 deletions src/squidpy/experimental/im/_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
the tile whose non-overlapping base region contains the centroid owns the
cell. Non-owned cells are zeroed out in each tile's mask so that
downstream processing never double-counts.

All functions accept pre-computed centroid dicts and image shapes — they
never materialize the full image or label array.
"""

from __future__ import annotations
Expand All @@ -13,6 +16,7 @@
from typing import Literal

import numpy as np
import xarray as xr
from skimage.measure import regionprops


Expand Down Expand Up @@ -49,8 +53,13 @@ class TileSpec:
owned_ids: frozenset[int] = field(default_factory=frozenset)


def _compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]:
"""Compute centroid and bounding-box size for every label.
# ---------------------------------------------------------------------------
# Centroid computation
# ---------------------------------------------------------------------------


def compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]:
"""Compute centroid and bounding-box size for every label from a numpy array.

Parameters
----------
Expand All @@ -75,32 +84,148 @@ def _compute_cell_info(labels: np.ndarray) -> dict[int, CellInfo]:
return info


def compute_cell_info_multiscale(
labels_node: xr.DataTree,
target_scale: str = "scale0",
) -> dict[int, CellInfo]:
"""Compute centroids using the coarsest scale of a multiscale label pyramid.

Reads only the smallest resolution, then scales coordinates to *target_scale*.
"""
available = list(labels_node.keys())
if not available:
return {}

# Pick coarsest scale (highest numeric suffix)
def _scale_idx(k: str) -> int:
num = "".join(c for c in k if c.isdigit())
return int(num) if num else 0

coarsest = max(available, key=_scale_idx)
coarse_da = labels_node[coarsest].ds["image"]
coarse_labels = np.asarray(coarse_da.values).squeeze()

if coarse_labels.ndim != 2:
raise ValueError(f"Expected 2-D labels at scale {coarsest}, got shape {coarse_labels.shape}")

# Compute scale factors to target resolution
target_da = labels_node[target_scale].ds["image"]
target_h, target_w = target_da.sizes.get("y", target_da.shape[-2]), target_da.sizes.get("x", target_da.shape[-1])
coarse_h, coarse_w = coarse_labels.shape
scale_y = target_h / coarse_h
scale_x = target_w / coarse_w

props = regionprops(coarse_labels)
return {
p.label: CellInfo(
label=p.label,
centroid_y=p.centroid[0] * scale_y,
centroid_x=p.centroid[1] * scale_x,
bbox_h=int(np.ceil((p.bbox[2] - p.bbox[0]) * scale_y)),
bbox_w=int(np.ceil((p.bbox[3] - p.bbox[1]) * scale_x)),
)
for p in props
}


def compute_cell_info_tiled(
labels_da: xr.DataArray,
chunk_size: int = 4096,
) -> dict[int, CellInfo]:
"""Compute centroids by reading label tiles — never materializes the full array.

For cells spanning multiple chunks, centroids are computed as
area-weighted means of per-chunk centroids.

Parameters
----------
labels_da
2-D (y, x) dask-backed xarray DataArray.
chunk_size
Size of chunks to read at a time.
"""
H = int(labels_da.sizes.get("y", labels_da.shape[-2]))
W = int(labels_da.sizes.get("x", labels_da.shape[-1]))

# Per-label accumulators: [sum_y*area, sum_x*area, total_area, min_y, max_y, min_x, max_x]
stats: dict[int, list[float]] = {}

for y0 in range(0, H, chunk_size):
y1 = min(y0 + chunk_size, H)
for x0 in range(0, W, chunk_size):
x1 = min(x0 + chunk_size, W)
chunk = labels_da.isel(y=slice(y0, y1), x=slice(x0, x1)).values
if chunk.ndim > 2:
chunk = chunk.squeeze()

for p in regionprops(chunk):
lid = p.label
cy_global = p.centroid[0] + y0
cx_global = p.centroid[1] + x0
area = p.area
min_row = p.bbox[0] + y0
max_row = p.bbox[2] + y0
min_col = p.bbox[1] + x0
max_col = p.bbox[3] + x0

if lid not in stats:
stats[lid] = [cy_global * area, cx_global * area, area, min_row, max_row, min_col, max_col]
else:
s = stats[lid]
s[0] += cy_global * area
s[1] += cx_global * area
s[2] += area
s[3] = min(s[3], min_row)
s[4] = max(s[4], max_row)
s[5] = min(s[5], min_col)
s[6] = max(s[6], max_col)

result: dict[int, CellInfo] = {}
for lid, s in stats.items():
if lid == 0:
continue
result[lid] = CellInfo(
label=lid,
centroid_y=s[0] / s[2],
centroid_x=s[1] / s[2],
bbox_h=int(s[4] - s[3]),
bbox_w=int(s[6] - s[5]),
)
return result


# ---------------------------------------------------------------------------
# Tile spec building
# ---------------------------------------------------------------------------


def _auto_margin(cell_info: dict[int, CellInfo]) -> int:
"""Compute the minimum margin that covers the largest cell's half-extent."""
if not cell_info:
return 0
max_half = max(max(c.bbox_h, c.bbox_w) for c in cell_info.values())
# Full bbox extent: a cell's centroid can be at most half a bbox away
# from its edge, so margin = ceil(max_extent / 2) guarantees coverage.
# Centroid can be at most half a bbox away from the cell's edge.
# Add 1 pixel for safety (rounding / off-by-one).
return int(np.ceil(max_half / 2)) + 1


def build_tile_specs(
labels: np.ndarray,
image_shape: tuple[int, int],
cell_info: dict[int, CellInfo],
tile_size: int = 2048,
overlap_margin: int | Literal["auto"] = "auto",
) -> list[TileSpec]:
"""Build tile specifications for a label image.
"""Build tile specifications from pre-computed centroids.

Each tile gets a non-overlapping *base* region (for centroid ownership)
and an extended *crop* region (base + margin on each side). Every
nonzero label is assigned to exactly one tile.
No pixel data is needed — only the image dimensions and centroid dict.

Parameters
----------
labels
2-D integer label image (0 = background).
image_shape
``(H, W)`` of the full-resolution image/labels.
cell_info
Pre-computed centroids from :func:`compute_cell_info`,
:func:`compute_cell_info_multiscale`, or :func:`compute_cell_info_tiled`.
tile_size
Side length of the non-overlapping base grid cells.
overlap_margin
Expand All @@ -112,14 +237,10 @@ def build_tile_specs(
List of :class:`TileSpec`, one per grid cell that owns at least one
label. Empty tiles (no cells) are omitted.
"""
if labels.ndim != 2:
raise ValueError(f"Expected 2-D labels, got shape {labels.shape}")
H, W = image_shape
if tile_size <= 0:
raise ValueError(f"tile_size must be positive, got {tile_size}")

H, W = labels.shape
cell_info = _compute_cell_info(labels)

if isinstance(overlap_margin, str) and overlap_margin == "auto":
margin = _auto_margin(cell_info)
else:
Expand Down Expand Up @@ -150,13 +271,11 @@ def build_tile_specs(
if not owned:
continue

# Base region (non-overlapping)
by0 = row * tile_size
bx0 = col * tile_size
by1 = min(by0 + tile_size, H)
bx1 = min(bx0 + tile_size, W)

# Crop region (with margin, clamped)
cy0 = max(by0 - margin, 0)
cx0 = max(bx0 - margin, 0)
cy1 = min(by1 + margin, H)
Expand All @@ -173,56 +292,130 @@ def build_tile_specs(
return specs


# ---------------------------------------------------------------------------
# Tile extraction
# ---------------------------------------------------------------------------


def extract_tile(
image: np.ndarray,
labels: np.ndarray,
spec: TileSpec,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract a tile's image and mask, zeroing out non-owned cells.
"""Extract a tile from numpy arrays, zeroing out non-owned cells.

Parameters
----------
image
3-D array of shape ``(C, H, W)``.
``(C, H, W)`` numpy array.
labels
2-D integer label image of shape ``(H, W)``.
``(H, W)`` numpy label array.
spec
Tile specification from :func:`build_tile_specs`.
Tile specification.

Returns
-------
tile_image
Cropped image of shape ``(C, crop_h, crop_w)``.
tile_labels
Cropped label image with non-owned cells zeroed out.
tile_image, tile_labels
"""
cy0, cx0, cy1, cx1 = spec.crop
tile_image = image[:, cy0:cy1, cx0:cx1]
tile_labels = labels[cy0:cy1, cx0:cx1].copy()
_zero_non_owned(tile_labels, spec.owned_ids)
return tile_image, tile_labels

# Zero out labels not owned by this tile
unique_in_crop = np.unique(tile_labels)
for lid in unique_in_crop:
if lid != 0 and lid not in spec.owned_ids:
tile_labels[tile_labels == lid] = 0

def extract_tile_lazy(
image_da: xr.DataArray,
labels_da: xr.DataArray,
spec: TileSpec,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract a tile from dask-backed xarray arrays.

Materializes only the tile's crop region (~2k×2k), not the full image.

Parameters
----------
image_da
``(c, y, x)`` dask-backed DataArray.
labels_da
``(y, x)`` dask-backed DataArray.
spec
Tile specification.

Returns
-------
tile_image
``(C, crop_h, crop_w)`` numpy array.
tile_labels
``(crop_h, crop_w)`` numpy array with non-owned cells zeroed.
"""
cy0, cx0, cy1, cx1 = spec.crop
tile_image = image_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values
tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy()
if tile_labels.ndim > 2:
tile_labels = tile_labels.squeeze()
_zero_non_owned(tile_labels, spec.owned_ids)
return tile_image, tile_labels


def extract_labels_tile_lazy(
labels_da: xr.DataArray,
spec: TileSpec,
) -> np.ndarray:
"""Extract a labels-only tile from a dask-backed DataArray.

Like :func:`extract_tile_lazy` but skips the image entirely.
Materializes only the crop region.

Parameters
----------
labels_da
``(y, x)`` dask-backed DataArray.
spec
Tile specification.

Returns
-------
``(crop_h, crop_w)`` numpy array with non-owned cells zeroed.
"""
cy0, cx0, cy1, cx1 = spec.crop
tile_labels = labels_da.isel(y=slice(cy0, cy1), x=slice(cx0, cx1)).values.copy()
if tile_labels.ndim > 2:
tile_labels = tile_labels.squeeze()
_zero_non_owned(tile_labels, spec.owned_ids)
return tile_labels


def _zero_non_owned(tile_labels: np.ndarray, owned_ids: frozenset[int]) -> None:
"""Zero out labels not in *owned_ids* (in-place)."""
for lid in np.unique(tile_labels):
if lid != 0 and lid not in owned_ids:
tile_labels[tile_labels == lid] = 0


# ---------------------------------------------------------------------------
# Coverage verification
# ---------------------------------------------------------------------------


def verify_coverage(
labels: np.ndarray,
all_label_ids: set[int],
specs: list[TileSpec],
) -> None:
"""Assert that tile specs provide full, non-overlapping cell coverage.

Parameters
----------
all_label_ids
Set of all nonzero label IDs expected in the image.
specs
Tile specifications to verify.

Raises
------
AssertionError
If any cell is missing or assigned to more than one tile.
"""
all_label_ids = set(np.unique(labels))
all_label_ids.discard(0)

owned_union: set[int] = set()
for spec in specs:
overlap = owned_union & spec.owned_ids
Expand Down
4 changes: 3 additions & 1 deletion src/squidpy/experimental/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

__all__ = []
from ._tiling_qc import tiling_qc

__all__ = ["tiling_qc"]
Loading