diff --git a/pyproject.toml b/pyproject.toml
index 698fe2e..df3a1dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -10,7 +10,7 @@ requires-python = ">=3.10.15"
 dynamic = ["version"]
 dependencies = [
     "ezmsg-baseproc>=1.0.2",
-    "ezmsg-sigproc>=2.10.0",
+    "ezmsg-sigproc>=2.14.0",
     "river>=0.22.0",
     "scikit-learn>=1.6.0",
     "torch>=2.6.0",
@@ -74,3 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]
 [tool.uv.sources]
 # Uncomment to use development version of ezmsg from git
 #ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
+#ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true }
\ No newline at end of file
diff --git a/src/ezmsg/learn/process/ssr.py b/src/ezmsg/learn/process/ssr.py
new file mode 100644
index 0000000..c4c81a2
--- /dev/null
+++ b/src/ezmsg/learn/process/ssr.py
@@ -0,0 +1,374 @@
+"""Self-supervised regression framework and LRR implementation.
+
+This module provides a general framework for self-supervised channel
+regression via :class:`SelfSupervisedRegressionTransformer`, and a
+concrete implementation — Linear Regression Rereferencing (LRR) — via
+:class:`LRRTransformer`.
+
+**Framework.** The base class accumulates the channel covariance
+``C = X^T X`` and solves per-cluster ridge regressions to obtain a weight
+matrix *W*. Subclasses define what to *do* with *W* by implementing
+:meth:`~SelfSupervisedRegressionTransformer._on_weights_updated` and
+:meth:`~SelfSupervisedRegressionTransformer._process`.
+
+**LRR.** For each channel *c*, predict it from the other channels in its
+cluster via ridge regression, then subtract the prediction::
+
+    y = X - X @ W = X @ (I - W)
+
+The effective weight matrix ``I - W`` is passed to
+:class:`~ezmsg.sigproc.affinetransform.AffineTransformTransformer`, which
+automatically exploits block-diagonal structure when ``channel_clusters``
+are provided.
+
+**Fitting.** Given data matrix *X* of shape ``(samples, channels)``, the
+sufficient statistic is the channel covariance ``C = X^T X``. When
+``incremental=True`` (default), *C* is accumulated across
+:meth:`~SelfSupervisedRegressionTransformer.partial_fit` calls.
+
+**Solving.** Within each cluster the weight matrix *W* is obtained from
+the inverse of the (ridge-regularised) cluster covariance
+``C_inv = (C_cluster + lambda * I)^{-1}`` using the block-inverse identity::
+
+    W[:, c] = -C_inv[:, c] / C_inv[c, c],    diag(W) = 0
+
+This replaces the naive per-channel Cholesky loop with a single matrix
+inverse per cluster, keeping the linear algebra in the source array
+namespace so that GPU-backed arrays benefit from device-side computation.
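+
+Illustrative numerical check of the identity in plain numpy (example only;
+the transformer itself stays in the source array namespace)::
+
+    >>> import numpy as np
+    >>> rng = np.random.default_rng(0)
+    >>> X = rng.standard_normal((1000, 4))
+    >>> C, lam = X.T @ X, 0.1
+    >>> C_inv = np.linalg.inv(C + lam * np.eye(4))
+    >>> W = -(C_inv / np.diag(C_inv))  # column c scaled by 1 / C_inv[c, c]
+    >>> W[np.diag_indices(4)] = 0.0
+    >>> r = [1, 2, 3]  # regress channel 0 on the rest
+    >>> w0 = np.linalg.solve(C[np.ix_(r, r)] + lam * np.eye(3), C[r, 0])
+    >>> bool(np.allclose(W[r, 0], w0))
+    True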
+""" + +from __future__ import annotations + +import os +import typing +from abc import abstractmethod +from pathlib import Path + +import ezmsg.core as ez +import numpy as np +from array_api_compat import get_namespace +from ezmsg.baseproc import ( + BaseAdaptiveTransformer, + BaseAdaptiveTransformerUnit, + processor_state, +) +from ezmsg.baseproc.protocols import SettingsType, StateType +from ezmsg.sigproc.affinetransform import ( + AffineTransformSettings, + AffineTransformTransformer, +) +from ezmsg.sigproc.util.array import array_device, xp_create +from ezmsg.util.messages.axisarray import AxisArray + +# --------------------------------------------------------------------------- +# Base: Self-supervised regression +# --------------------------------------------------------------------------- + + +class SelfSupervisedRegressionSettings(ez.Settings): + """Settings common to all self-supervised regression modes.""" + + weights: np.ndarray | str | Path | None = None + """Pre-calculated weight matrix *W* or path to a CSV file (``np.loadtxt`` + compatible). If provided, the transformer is ready immediately.""" + + axis: str | None = None + """Channel axis name. ``None`` defaults to the last dimension.""" + + channel_clusters: list[list[int]] | None = None + """Per-cluster regression. ``None`` treats all channels as one cluster.""" + + ridge_lambda: float = 0.0 + """Ridge (L2) regularisation parameter.""" + + incremental: bool = True + """When ``True``, accumulate ``X^T X`` across :meth:`partial_fit` calls. + When ``False``, each call replaces the previous statistics.""" + + +@processor_state +class SelfSupervisedRegressionState: + cxx: object | None = None # Array API; namespace matches source data. + n_samples: int = 0 + weights: object | None = None # Array API; namespace matches cxx. + + +class SelfSupervisedRegressionTransformer( + BaseAdaptiveTransformer[SettingsType, AxisArray, AxisArray, StateType], + typing.Generic[SettingsType, StateType], +): + """Abstract base for self-supervised regression transformers. + + Subclasses must implement: + + * :meth:`_on_weights_updated` — called whenever the weight matrix *W* is + (re)computed, so the subclass can build whatever internal transform it + needs (e.g. ``I - W`` for LRR). + * :meth:`_process` — the per-message transform step. + """ + + # -- message hash / state management ------------------------------------ + + def _hash_message(self, message: AxisArray) -> int: + axis = self.settings.axis or message.dims[-1] + axis_idx = message.get_axis_idx(axis) + return hash((message.key, message.data.shape[axis_idx])) + + def _reset_state(self, message: AxisArray) -> None: + axis = self.settings.axis or message.dims[-1] + axis_idx = message.get_axis_idx(axis) + n_channels = message.data.shape[axis_idx] + + self._validate_clusters(n_channels) + self._state.cxx = None + self._state.n_samples = 0 + self._state.weights = None + + # If pre-calculated weights are provided, load and go. 
+
+    # -- partial_fit (self-supervised, accepts AxisArray) --------------------
+
+    def partial_fit(self, message: AxisArray) -> None:  # type: ignore[override]
+        xp = get_namespace(message.data)
+
+        if xp.any(xp.isnan(message.data)):
+            return
+
+        # Hash check / state reset
+        msg_hash = self._hash_message(message)
+        if self._hash != msg_hash:
+            self._reset_state(message)
+            self._hash = msg_hash
+
+        axis = self.settings.axis or message.dims[-1]
+        axis_idx = message.get_axis_idx(axis)
+        data = message.data
+
+        # Move channel axis to last, flatten to 2-D
+        if axis_idx != data.ndim - 1:
+            perm = list(range(data.ndim))
+            perm.append(perm.pop(axis_idx))
+            data = xp.permute_dims(data, perm)
+
+        n_channels = data.shape[-1]
+        X = xp.reshape(data, (-1, n_channels))
+
+        # Covariance stays in the source namespace for accumulation.
+        cxx_new = xp.matmul(xp.permute_dims(X, (1, 0)), X)
+
+        if self.settings.incremental and self._state.cxx is not None:
+            self._state.cxx = self._state.cxx + cxx_new
+            self._state.n_samples += int(X.shape[0])
+        else:
+            # Non-incremental: replace both the statistics and the count.
+            self._state.cxx = cxx_new
+            self._state.n_samples = int(X.shape[0])
+
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    # -- convenience APIs ----------------------------------------------------
+
+    def fit(self, X: np.ndarray) -> None:
+        """Batch fit from a raw numpy array (samples x channels)."""
+        n_channels = X.shape[-1]
+        self._validate_clusters(n_channels)
+        X = np.asarray(X, dtype=np.float64).reshape(-1, n_channels)
+        self._state.cxx = X.T @ X
+        self._state.n_samples = X.shape[0]
+        self._state.weights = self._solve_weights(self._state.cxx)
+        self._on_weights_updated()
+
+    def fit_transform(self, message: AxisArray) -> AxisArray:
+        """Convenience: ``partial_fit`` then ``_process``."""
+        self.partial_fit(message)
+        return self._process(message)
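+
+    # Illustrative streaming usage (``message`` is an AxisArray shaped
+    # time x channels; ``LRRTransformer`` is the concrete subclass below):
+    #
+    #     proc = LRRTransformer(LRRSettings(ridge_lambda=1.0))
+    #     proc.partial_fit(message)     # accumulate X^T X and solve W
+    #     cleaned = proc.send(message)  # apply X @ (I - W)
+    #
+    # Note that ``fit`` consumes a raw array and does not prime the message
+    # hash used by ``send``, so the first ``send`` after a bare ``fit``
+    # triggers ``_reset_state`` (pre-calculated ``settings.weights`` survive
+    # this; statistics fitted from data do not). The benchmarks prime via
+    # ``partial_fit`` for exactly this reason.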
+
+    # -- abstract hooks for subclasses ---------------------------------------
+
+    @abstractmethod
+    def _on_weights_updated(self) -> None:
+        """Called after ``self._state.weights`` has been set/updated.
+
+        Subclasses should build or refresh whatever internal transform
+        object they need for :meth:`_process`.
+        """
+        ...
+
+    @abstractmethod
+    def _process(self, message: AxisArray) -> AxisArray: ...
+
+
+# ---------------------------------------------------------------------------
+# Concrete: Linear Regression Rereferencing (LRR)
+# ---------------------------------------------------------------------------
+
+
+class LRRSettings(SelfSupervisedRegressionSettings):
+    """Settings for :class:`LRRTransformer`."""
+
+    min_cluster_size: int = 32
+    """Passed to :class:`AffineTransformTransformer` as the block-diagonal
+    merge threshold."""
+
+
+@processor_state
+class LRRState(SelfSupervisedRegressionState):
+    affine: AffineTransformTransformer | None = None
+
+
+class LRRTransformer(
+    SelfSupervisedRegressionTransformer[LRRSettings, LRRState],
+):
+    """Adaptive LRR transformer.
+
+    ``partial_fit`` accepts a plain :class:`AxisArray` (self-supervised),
+    and the transform step is delegated to an internal
+    :class:`AffineTransformTransformer`.
+    """
+
+    # -- state management (clear own state, then delegate to base) ----------
+
+    def _reset_state(self, message: AxisArray) -> None:
+        self._state.affine = None
+        super()._reset_state(message)
+
+    # -- weights → affine transform ------------------------------------------
+
+    def _on_weights_updated(self) -> None:
+        xp = get_namespace(self._state.weights)
+        dev = array_device(self._state.weights)
+        n = self._state.weights.shape[0]
+        effective = (
+            xp_create(xp.eye, n, dtype=self._state.weights.dtype, device=dev)
+            - self._state.weights
+        )
+
+        # Update the existing affine transformer's weights in place when one
+        # is already built (avoids a full _reset_state round-trip on every
+        # partial_fit); otherwise construct it fresh.
+        if self._state.affine is not None:
+            self._state.affine.set_weights(effective)
+        else:
+            self._state.affine = AffineTransformTransformer(
+                AffineTransformSettings(
+                    weights=effective,
+                    axis=self.settings.axis,
+                    channel_clusters=self.settings.channel_clusters,
+                    min_cluster_size=self.settings.min_cluster_size,
+                )
+            )
+
+    # -- transform -------------------------------------------------------------
+
+    def _process(self, message: AxisArray) -> AxisArray:
+        if self._state.affine is None:
+            raise RuntimeError(
+                "LRRTransformer has not been fitted. Call partial_fit() or provide pre-calculated weights."
+            )
+        return self._state.affine(message)
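+
+
+# Illustrative check in plain numpy (mirrors the unit tests; assumes ``proc``
+# is a fitted LRRTransformer, ``msg`` the AxisArray it was fitted on, and
+# ``X = msg.data``):
+#
+#     W = proc.state.weights
+#     out = proc.send(msg)
+#     assert np.allclose(out.data, X @ (np.eye(W.shape[0]) - W))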
+
+
+class LRRUnit(
+    BaseAdaptiveTransformerUnit[
+        LRRSettings,
+        AxisArray,
+        AxisArray,
+        LRRTransformer,
+    ],
+):
+    """ezmsg Unit wrapping :class:`LRRTransformer`.
+
+    Follows the :class:`BaseAdaptiveDecompUnit` pattern — accepts
+    :class:`AxisArray` (not :class:`SampleMessage`) for self-supervised
+    training via ``INPUT_SAMPLE``.
+    """
+
+    SETTINGS = LRRSettings
+
+    INPUT_SAMPLE = ez.InputStream(AxisArray)
+
+    @ez.subscriber(INPUT_SAMPLE)
+    async def on_sample(self, msg: AxisArray) -> None:
+        await self.processor.apartial_fit(msg)
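+
+
+# Sketch of intended wiring (illustrative: the signal-stream attribute names
+# on the base unit are assumed from the usual ezmsg convention, and
+# ``source`` is a hypothetical upstream unit). Route the same stream to both
+# inputs so the unit keeps adapting while it transforms:
+#
+#     lrr = LRRUnit(LRRSettings(channel_clusters=clusters, ridge_lambda=1.0))
+#     # connections: (source.OUTPUT_SIGNAL, lrr.INPUT_SIGNAL)   # transform
+#     #              (source.OUTPUT_SIGNAL, lrr.INPUT_SAMPLE)   # adaptation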
diff --git a/tests/benchmark/bench_lrr.py b/tests/benchmark/bench_lrr.py
new file mode 100644
index 0000000..d77955c
--- /dev/null
+++ b/tests/benchmark/bench_lrr.py
@@ -0,0 +1,317 @@
+"""Performance benchmarks for LRRTransformer.
+
+Run with:
+    .venv/bin/python tests/benchmark/bench_lrr.py
+
+Benchmarks:
+    1. _process (inference) — numpy, varying chunk sizes
+    2. partial_fit (training) — numpy, varying chunk sizes
+    3. _process — torch MPS (Apple Silicon GPU)
+    4. partial_fit — torch MPS (Apple Silicon GPU)
+    5. _process — MLX
+    6. partial_fit — MLX
+"""
+
+import time
+
+import numpy as np
+from ezmsg.util.messages.axisarray import AxisArray
+
+from ezmsg.learn.process.ssr import LRRSettings, LRRTransformer
+
+# ---------------------------------------------------------------------------
+# Parameters
+# ---------------------------------------------------------------------------
+
+N_CH = 512
+N_CLUSTERS = 8
+CLUSTER_SIZE = N_CH // N_CLUSTERS  # 64
+FS = 30_000.0
+CHUNK_SIZES = [20, 50, 100, 150, 200, 300]
+WARMUP_ITERS = 20
+BENCH_ITERS = 200
+
+CLUSTERS = [list(range(i * CLUSTER_SIZE, (i + 1) * CLUSTER_SIZE)) for i in range(N_CLUSTERS)]
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+
+def _make_msg(data, key: str = "bench") -> AxisArray:
+    return AxisArray(
+        data=data,
+        dims=["time", "ch"],
+        axes={"time": AxisArray.TimeAxis(fs=FS, offset=0.0)},
+        key=key,
+    )
+
+
+def _print_header(title: str) -> None:
+    print(f"\n{'=' * 70}")
+    print(f" {title}")
+    print(f"{'=' * 70}")
+
+
+def _print_row(chunk: int, median_us: float, throughput_khz: float) -> None:
+    print(f"  chunk={chunk:>4d} | {median_us:8.1f} us/call | {throughput_khz:8.1f} kHz effective")
+
+
+def _bench_loop(fn, n_warmup: int, n_iters: int) -> list[float]:
+    """Run fn() for warmup + measured iterations, return list of elapsed times."""
+    for _ in range(n_warmup):
+        fn()
+    times = []
+    for _ in range(n_iters):
+        t0 = time.perf_counter()
+        fn()
+        times.append(time.perf_counter() - t0)
+    return times
+
+
+def _bench_loop_sync(fn, sync_fn, n_warmup: int, n_iters: int) -> list[float]:
+    """Like _bench_loop, but calls sync_fn() after fn() inside each timed
+    iteration so device-side work is included in the measurement."""
+    for _ in range(n_warmup):
+        fn()
+        sync_fn()
+    times = []
+    for _ in range(n_iters):
+        t0 = time.perf_counter()
+        fn()
+        sync_fn()
+        times.append(time.perf_counter() - t0)
+    return times
+
+
+# ---------------------------------------------------------------------------
+# NumPy benchmarks
+# ---------------------------------------------------------------------------
+
+
+def bench_process_numpy() -> None:
+    _print_header("_process (inference) — NumPy")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    print()
+
+    rng = np.random.default_rng(0)
+
+    # Fit via partial_fit so the message hash is primed for send()
+    fit_data = rng.standard_normal((2000, N_CH))
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+    proc.partial_fit(_make_msg(fit_data))
+
+    for chunk in CHUNK_SIZES:
+        data = rng.standard_normal((chunk, N_CH))
+        msg = _make_msg(data)
+        # Prime — first send triggers the affine's _reset_state
+        proc.send(msg)
+
+        times = _bench_loop(lambda: proc.send(msg), WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)  # samples/s
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+def bench_partial_fit_numpy() -> None:
+    _print_header("partial_fit (training) — NumPy")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    print()
+
+    rng = np.random.default_rng(1)
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+
+    for chunk in CHUNK_SIZES:
+        data = rng.standard_normal((chunk, N_CH))
+        msg = _make_msg(data)
+        # Prime
+        proc.partial_fit(msg)
+
+        times = _bench_loop(lambda: proc.partial_fit(msg), WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+# ---------------------------------------------------------------------------
+# Torch MPS benchmarks
+# ---------------------------------------------------------------------------
+
+
+def bench_process_mps() -> None:
+    import torch
+
+    if not torch.backends.mps.is_available():
+        print("\n  [SKIPPED] MPS not available")
+        return
+
+    _print_header("_process (inference) — Torch MPS")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    print()
+
+    rng = np.random.default_rng(0)
+    device = torch.device("mps")
+
+    # Fit on CPU (numpy), then send MPS data to trigger device conversion
+    fit_data = rng.standard_normal((2000, N_CH))
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+    proc.partial_fit(_make_msg(fit_data))
+
+    def sync():
+        torch.mps.synchronize()
+
+    for chunk in CHUNK_SIZES:
+        data_mps = torch.randn(chunk, N_CH, device=device, dtype=torch.float32)
+        msg = _make_msg(data_mps)
+        # Prime — first send triggers affine's _reset_state with device conversion
+        proc.send(msg)
+
+        times = _bench_loop_sync(lambda: proc.send(msg), sync, WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+def bench_partial_fit_mps() -> None:
+    import torch
+
+    if not torch.backends.mps.is_available():
+        print("\n  [SKIPPED] MPS not available")
+        return
+
+    _print_header("partial_fit (training) — Torch MPS")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    print()
+
+    device = torch.device("mps")
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+
+    def sync():
+        torch.mps.synchronize()
+
+    for chunk in CHUNK_SIZES:
+        data_mps = torch.randn(chunk, N_CH, device=device, dtype=torch.float32)
+        msg = _make_msg(data_mps)
+        # Prime
+        proc.partial_fit(msg)
+
+        times = _bench_loop_sync(lambda: proc.partial_fit(msg), sync, WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+# ---------------------------------------------------------------------------
+# MLX benchmarks
+# ---------------------------------------------------------------------------
+
+
+def bench_process_mlx() -> None:
+    try:
+        import mlx.core as mx
+    except ImportError:
+        print("\n  [SKIPPED] MLX not installed")
+        return
+
+    _print_header("_process (inference) — MLX")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    print()
+
+    rng = np.random.default_rng(0)
+
+    # Fit on CPU (numpy), then send MLX data
+    fit_data = rng.standard_normal((2000, N_CH))
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+    proc.partial_fit(_make_msg(fit_data))
+
+    for chunk in CHUNK_SIZES:
+        data_mlx = mx.random.normal(shape=(chunk, N_CH))
+        msg = _make_msg(data_mlx)
+        # Prime — first send triggers affine's _reset_state with MLX conversion
+        out = proc.send(msg)
+        mx.eval(out.data)
+
+        def run():
+            out = proc.send(msg)
+            mx.eval(out.data)
+
+        times = _bench_loop(run, WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+def bench_partial_fit_mlx() -> None:
+    try:
+        import mlx.core as mx
+    except ImportError:
+        print("\n  [SKIPPED] MLX not installed")
+        return
+
+    _print_header("partial_fit (training) — MLX")
+    print(f"  {N_CH} channels, {N_CLUSTERS}x{CLUSTER_SIZE} clusters, {WARMUP_ITERS} warmup, {BENCH_ITERS} iters")
+    # MLX linalg.inv doesn't support GPU yet; run inv on CPU stream
+    print("  NOTE: linalg.inv runs on mx.cpu stream (GPU not supported)")
+    print()
+
+    proc = LRRTransformer(LRRSettings(channel_clusters=CLUSTERS, min_cluster_size=1))
+
+    # Monkey-patch _solve_weights to use mx.cpu stream for inv
+    original_solve = proc._solve_weights
+
+    def _solve_weights_cpu_inv(cxx):
+        from array_api_compat import get_namespace
+
+        xp = get_namespace(cxx)
+        # If this is MLX, we need to override linalg.inv
+        if xp.__name__ == "mlx.core":
+            orig_inv = mx.linalg.inv
+            mx.linalg.inv = lambda a: orig_inv(a, stream=mx.cpu)
+            try:
+                return original_solve(cxx)
+            finally:
+                mx.linalg.inv = orig_inv
+        return original_solve(cxx)
+
+    proc._solve_weights = _solve_weights_cpu_inv
+
+    for chunk in CHUNK_SIZES:
+        data_mlx = mx.random.normal(shape=(chunk, N_CH))
+        msg = _make_msg(data_mlx)
+        # Prime
+        proc.partial_fit(msg)
+
+        def run():
+            proc.partial_fit(msg)
+            # Force evaluation of the lazily computed weights; a bare
+            # mx.eval() with no arguments evaluates nothing.
+            mx.eval(proc.state.weights)
+
+        times = _bench_loop(run, WARMUP_ITERS, BENCH_ITERS)
+        median_us = np.median(times) * 1e6
+        throughput = chunk / np.median(times)
+        _print_row(chunk, median_us, throughput / 1e3)
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+if __name__ == "__main__":
+    print(f"LRRTransformer benchmark: {N_CH} channels, {N_CLUSTERS} clusters of {CLUSTER_SIZE}, fs={FS / 1e3:.0f} kHz")
+
+    bench_process_numpy()
+    bench_partial_fit_numpy()
+    bench_process_mps()
+    bench_partial_fit_mps()
+    bench_process_mlx()
+    bench_partial_fit_mlx()
+
+    print()
+    realtime_budget_us = {c: c / FS * 1e6 for c in CHUNK_SIZES}
+    print("Real-time budgets at 30 kHz:")
+    for chunk, budget in realtime_budget_us.items():
+        print(f"  chunk={chunk:>4d} -> {budget:8.1f} us")
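+    # e.g. chunk=20 at fs=30 kHz leaves 20 / 30_000 * 1e6 ≈ 666.7 us per call.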
kHz") + + bench_process_numpy() + bench_partial_fit_numpy() + bench_process_mps() + bench_partial_fit_mps() + bench_process_mlx() + bench_partial_fit_mlx() + + print() + realtime_budget_us = {c: c / FS * 1e6 for c in CHUNK_SIZES} + print("Real-time budgets at 30 kHz:") + for chunk, budget in realtime_budget_us.items(): + print(f" chunk={chunk:>4d} -> {budget:8.1f} us") diff --git a/tests/unit/test_ssr.py b/tests/unit/test_ssr.py new file mode 100644 index 0000000..83d035d --- /dev/null +++ b/tests/unit/test_ssr.py @@ -0,0 +1,324 @@ +"""Tests for ezmsg.learn.process.ssr (Linear Regression Rereferencing).""" + +import tempfile + +import numpy as np +import pytest +from ezmsg.util.messages.axisarray import AxisArray + +from ezmsg.learn.process.ssr import LRRSettings, LRRTransformer + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_axisarray( + data: np.ndarray, + fs: float = 100.0, + ch_axis: str = "ch", + dims: list[str] | None = None, + key: str = "test", +) -> AxisArray: + """Create an AxisArray from 2-D (time x ch) data.""" + if dims is None: + dims = ["time", ch_axis] + axes = {"time": AxisArray.TimeAxis(fs=fs, offset=0.0)} + return AxisArray(data=data, dims=dims, axes=axes, key=key) + + +def _random_data(n_times: int = 200, n_ch: int = 8, rng=None) -> np.ndarray: + if rng is None: + rng = np.random.default_rng(42) + return rng.standard_normal((n_times, n_ch)) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestFitThenProcessShape: + def test_fit_then_process_shape(self): + """Output shape must match input shape.""" + rng = np.random.default_rng(0) + X = _random_data(rng=rng) + msg = _make_axisarray(X) + + proc = LRRTransformer(LRRSettings()) + proc.partial_fit(msg) + out = proc.send(msg) + + assert isinstance(out, AxisArray) + assert out.data.shape == X.shape + + +class TestProcessBeforeFitRaises: + def test_process_before_fit_raises(self): + """Calling process before fitting must raise RuntimeError.""" + msg = _make_axisarray(_random_data()) + proc = LRRTransformer(LRRSettings()) + with pytest.raises(RuntimeError, match="not been fitted"): + proc.send(msg) + + +class TestEffectiveWeightsIMinusW: + def test_effective_weights_I_minus_W(self): + """Output equals X @ (I - W) computed manually.""" + rng = np.random.default_rng(1) + X = _random_data(n_times=300, n_ch=4, rng=rng) + msg = _make_axisarray(X) + + proc = LRRTransformer(LRRSettings()) + proc.partial_fit(msg) + out = proc.send(msg) + + W = proc.state.weights + expected = X @ (np.eye(W.shape[0]) - W) + np.testing.assert_allclose(out.data, expected, atol=1e-10) + + +class TestDiagonalZero: + def test_diagonal_zero(self): + """Diagonal of W must always be zero.""" + rng = np.random.default_rng(2) + X = _random_data(rng=rng) + msg = _make_axisarray(X) + + proc = LRRTransformer(LRRSettings()) + proc.partial_fit(msg) + + np.testing.assert_array_equal(np.diag(proc.state.weights), 0.0) + + +class TestChannelClusters: + def test_channel_clusters(self): + """Cross-cluster weights must be zero; within-cluster weights non-zero.""" + rng = np.random.default_rng(3) + n_ch = 8 + clusters = [[0, 1, 2, 3], [4, 5, 6, 7]] + X = _random_data(n_ch=n_ch, rng=rng) + msg = _make_axisarray(X) + + proc = LRRTransformer(LRRSettings(channel_clusters=clusters)) + proc.partial_fit(msg) + + W = 
+
+        # Cross-cluster should be zero
+        for c1 in clusters:
+            for c2 in clusters:
+                if c1 is c2:
+                    continue
+                cross = W[np.ix_(c1, c2)]
+                np.testing.assert_array_equal(cross, 0.0)
+
+        # Within-cluster (off-diagonal) should be non-zero
+        for cluster in clusters:
+            sub = W[np.ix_(cluster, cluster)]
+            off_diag = sub[~np.eye(len(cluster), dtype=bool)]
+            assert np.any(off_diag != 0), "Expected non-zero within-cluster weights"
+
+
+class TestIncrementalAccumulates:
+    def test_incremental_accumulates(self):
+        """Two partial_fits with incremental=True should match one fit on concatenated data."""
+        rng = np.random.default_rng(4)
+        X1 = _random_data(n_times=100, rng=rng)
+        X2 = _random_data(n_times=100, rng=rng)
+
+        # Incremental: two calls
+        proc_inc = LRRTransformer(LRRSettings(incremental=True))
+        proc_inc.partial_fit(_make_axisarray(X1))
+        proc_inc.partial_fit(_make_axisarray(X2))
+
+        # Batch: one call on concatenated data
+        proc_batch = LRRTransformer(LRRSettings(incremental=False))
+        proc_batch.partial_fit(_make_axisarray(np.concatenate([X1, X2], axis=0)))
+
+        np.testing.assert_allclose(proc_inc.state.weights, proc_batch.state.weights, atol=1e-10)
+
+
+class TestBatchResetsEachCall:
+    def test_batch_resets_each_call(self):
+        """With incremental=False, the second partial_fit ignores the first."""
+        rng = np.random.default_rng(5)
+        X1 = _random_data(n_times=100, rng=rng)
+        X2 = _random_data(n_times=100, rng=rng)
+
+        # Non-incremental: two calls
+        proc = LRRTransformer(LRRSettings(incremental=False))
+        proc.partial_fit(_make_axisarray(X1))
+        proc.partial_fit(_make_axisarray(X2))
+
+        # Reference: single fit on X2 only
+        proc_ref = LRRTransformer(LRRSettings(incremental=False))
+        proc_ref.partial_fit(_make_axisarray(X2))
+
+        np.testing.assert_allclose(proc.state.weights, proc_ref.state.weights, atol=1e-10)
+
+
+class TestRidgeHandlesCollinearity:
+    def test_ridge_handles_collinearity(self):
+        """Identical channels should not crash when ridge_lambda > 0."""
+        rng = np.random.default_rng(6)
+        base = rng.standard_normal((200, 1))
+        X = np.hstack([base, base, rng.standard_normal((200, 2))])
+        msg = _make_axisarray(X)
+
+        proc = LRRTransformer(LRRSettings(ridge_lambda=1.0))
+        proc.partial_fit(msg)
+        out = proc.send(msg)
+
+        assert out.data.shape == X.shape
+        assert np.all(np.isfinite(out.data))
+
+
+class TestNanDataSkipped:
+    def test_nan_data_skipped(self):
+        """partial_fit with NaN data is a no-op."""
+        rng = np.random.default_rng(7)
+        X_good = _random_data(rng=rng)
+        X_nan = _random_data(rng=rng)
+        X_nan[0, 0] = np.nan
+
+        proc = LRRTransformer(LRRSettings())
+        proc.partial_fit(_make_axisarray(X_good))
+        W_before = proc.state.weights.copy()
+
+        proc.partial_fit(_make_axisarray(X_nan))
+        np.testing.assert_array_equal(proc.state.weights, W_before)
+
+
+class TestCustomAxisName:
+    def test_custom_axis_name(self):
+        """Works when the channel axis has a custom name like 'sensor'."""
+        rng = np.random.default_rng(8)
+        X = _random_data(n_ch=4, rng=rng)
+        msg = AxisArray(
+            data=X,
+            dims=["time", "sensor"],
+            axes={"time": AxisArray.TimeAxis(fs=100.0, offset=0.0)},
+            key="test",
+        )
+
+        proc = LRRTransformer(LRRSettings(axis="sensor"))
+        proc.partial_fit(msg)
+        out = proc.send(msg)
+
+        assert out.data.shape == X.shape
+
+
+class TestNonLastAxis:
+    def test_non_last_axis(self):
+        """Channel axis in a middle position."""
+        rng = np.random.default_rng(9)
+        n_ch = 4
+        # shape: (ch, time) — channels first
+        X = rng.standard_normal((n_ch, 50))
+        msg = AxisArray(
+            data=X,
+            dims=["ch", "time"],
+            axes={"time": AxisArray.TimeAxis(fs=100.0, offset=0.0)},
+            key="test",
+        )
+
+        proc = LRRTransformer(LRRSettings(axis="ch"))
+        proc.partial_fit(msg)
+        out = proc.send(msg)
+
+        assert out.data.shape == X.shape
+
+
+class TestFitTransform:
+    def test_fit_transform(self):
+        """fit_transform matches separate partial_fit + process."""
+        rng = np.random.default_rng(10)
+        X = _random_data(rng=rng)
+        msg = _make_axisarray(X)
+
+        proc1 = LRRTransformer(LRRSettings())
+        out1 = proc1.fit_transform(msg)
+
+        proc2 = LRRTransformer(LRRSettings())
+        proc2.partial_fit(msg)
+        out2 = proc2.send(msg)
+
+        np.testing.assert_allclose(out1.data, out2.data, atol=1e-12)
+
+
+class TestInvalidClusterIndicesRaise:
+    def test_invalid_cluster_indices_raise(self):
+        """Out-of-range indices in channel_clusters should raise ValueError."""
+        rng = np.random.default_rng(11)
+        X = _random_data(n_ch=4, rng=rng)
+        msg = _make_axisarray(X)
+
+        proc = LRRTransformer(LRRSettings(channel_clusters=[[0, 1, 99]]))
+        with pytest.raises(ValueError, match="out-of-range"):
+            proc.partial_fit(msg)
+
+
+class TestClustersEngageBlockDiagonal:
+    def test_clusters_engage_block_diagonal(self):
+        """When clusters make I-W block-diagonal, AffineTransform's cluster path is used."""
+        rng = np.random.default_rng(12)
+        n_ch = 8
+        clusters = [[0, 1, 2, 3], [4, 5, 6, 7]]
+        X = _random_data(n_ch=n_ch, n_times=300, rng=rng)
+        msg = _make_axisarray(X)
+
+        proc = LRRTransformer(LRRSettings(channel_clusters=clusters, min_cluster_size=1))
+        proc.partial_fit(msg)
+        out = proc.send(msg)
+
+        # Verify output is correct — the block-diagonal path should produce
+        # the same result as a full matmul.
+        W = proc.state.weights
+        expected = X @ (np.eye(n_ch) - W)
+        np.testing.assert_allclose(out.data, expected, atol=1e-10)
+
+
+class TestPrecalculatedWeights:
+    def test_precalculated_weights(self):
+        """Pre-calculated weights skip fit and produce correct output."""
+        rng = np.random.default_rng(13)
+        n_ch = 4
+        X = _random_data(n_ch=n_ch, rng=rng)
+
+        # Fit once to get weights
+        proc_fit = LRRTransformer(LRRSettings())
+        proc_fit.partial_fit(_make_axisarray(X))
+        W = proc_fit.state.weights.copy()
+
+        # Use pre-calculated weights
+        proc_pre = LRRTransformer(LRRSettings(weights=W))
+        msg = _make_axisarray(X)
+        out = proc_pre.send(msg)
+
+        expected = X @ (np.eye(n_ch) - W)
+        np.testing.assert_allclose(out.data, expected, atol=1e-10)
+
+
+class TestPrecalculatedWeightsFromFile:
+    def test_precalculated_weights_from_file(self):
+        """Load pre-calculated weights from a CSV file."""
+        rng = np.random.default_rng(14)
+        n_ch = 4
+        X = _random_data(n_ch=n_ch, rng=rng)
+
+        # Fit once to get weights
+        proc_fit = LRRTransformer(LRRSettings())
+        proc_fit.partial_fit(_make_axisarray(X))
+        W = proc_fit.state.weights.copy()
+
+        with tempfile.NamedTemporaryFile(suffix=".csv", delete=False, mode="w") as f:
+            np.savetxt(f, W, delimiter=",")
+            path = f.name
+
+        proc_pre = LRRTransformer(LRRSettings(weights=path))
+        msg = _make_axisarray(X)
+        out = proc_pre.send(msg)
+
+        expected = X @ (np.eye(n_ch) - W)
+        np.testing.assert_allclose(out.data, expected, atol=1e-10)