diff --git a/docs/source/guides/array_api.rst b/docs/source/guides/array_api.rst
new file mode 100644
index 0000000..d7e15d5
--- /dev/null
+++ b/docs/source/guides/array_api.rst
@@ -0,0 +1,246 @@
+Array API Compatibility
+=======================
+
+ezmsg-learn uses the `Array API standard <https://data-apis.org/array-api/latest/>`_
+to allow processors to operate on arrays from different backends — NumPy, CuPy,
+PyTorch, and others — without code changes.
+
+.. contents:: On this page
+   :local:
+   :depth: 2
+
+
+How It Works
+------------
+
+Modules that support the Array API derive the array namespace from their input
+data using ``array_api_compat.get_namespace()``:
+
+.. code-block:: python
+
+   from array_api_compat import get_namespace
+
+   def process(self, data):
+       xp = get_namespace(data)      # numpy, cupy, torch, etc.
+       result = xp.linalg.inv(data)  # dispatches to the right backend
+       return result
+
+This means that if you pass a CuPy array, all computation stays on the GPU.
+If you pass a NumPy array, it behaves exactly as before.
+
+Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement
+and creation functions portably (a usage sketch follows the list):
+
+- ``array_device(x)`` — returns the device of an array, or ``None``
+- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation
+  functions (``zeros``, ``eye``) with optional device
+- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray``
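+
+A minimal usage sketch (``make_state`` is a hypothetical helper, not part of
+the package; the signatures are those listed above):
+
+.. code-block:: python
+
+   from array_api_compat import get_namespace
+   from ezmsg.sigproc.util.array import array_device, xp_asarray, xp_create
+
+   def make_state(data, n_states):
+       # Hypothetical helper: allocate state alongside ``data``.
+       xp = get_namespace(data)
+       dev = array_device(data)  # e.g. a CUDA device, or None for numpy
+       x = xp_create(xp.zeros, (n_states,), dtype=xp.float64, device=dev)
+       P = xp_create(xp.eye, n_states, dtype=xp.float64, device=dev)
+       w = xp_asarray(xp, [1.0] * n_states, device=dev)
+       return x, P, w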
+
+
+Module Compatibility
+--------------------
+
+The tables below summarise the Array API status of each module.
+
+Fully compatible
+^^^^^^^^^^^^^^^^
+
+These modules perform all computation in the source array namespace; the
+inverse square root used by ``model.cca`` is sketched after the table.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 35 65
+
+   * - Module
+     - Notes
+   * - ``process.ssr``
+     - LRR / self-supervised regression. Full Array API.
+   * - ``model.cca``
+     - Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an
+       eigendecomposition-based inverse square root using only Array API ops.
+   * - ``process.rnn``
+     - PyTorch-native; operates on ``torch.Tensor`` throughout.
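+
+For reference, the replacement used by ``model.cca`` is the following
+eigendecomposition-based inverse square root (reproduced from
+``ezmsg.learn.model.cca._inv_sqrtm_spd``):
+
+.. code-block:: python
+
+   def _inv_sqrtm_spd(xp, A):
+       """inv(sqrtm(A)) for symmetric positive-definite A, Array API only."""
+       eigenvalues, eigenvectors = xp.linalg.eigh(A)
+       eigenvalues = xp.clip(eigenvalues, 1e-12, None)  # avoid div-by-zero
+       inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues)
+       # Q @ diag(v) == Q * v (broadcasting), then @ Q^T
+       return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors)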
+
+Mostly compatible (with NumPy boundaries)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+These modules use the Array API for data manipulation but fall back to NumPy
+at specific points where a dependency requires it; the boundary pattern is
+sketched after the table.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 35 40
+
+   * - Module
+     - NumPy boundary
+     - Reason
+   * - ``model.refit_kalman``
+     - ``_compute_gain()``
+     - ``scipy.linalg.solve_discrete_are`` has no Array API equivalent.
+       Matrices are converted to NumPy for the DARE solver, then converted back.
+   * - ``model.refit_kalman``
+     - ``refit()`` mutation loop
+     - Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors
+       and scalar element assignment.
+   * - ``process.refit_kalman``
+     - Inherits boundaries from model
+     - State init and output arrays use the source namespace.
+   * - ``process.slda``
+     - ``predict_proba``
+     - sklearn ``LinearDiscriminantAnalysis`` requires NumPy input.
+   * - ``process.adaptive_linear_regressor``
+     - ``partial_fit`` / ``predict``
+     - sklearn and river models require NumPy / pandas input.
+   * - ``dim_reduce.adaptive_decomp``
+     - ``partial_fit`` / ``transform``
+     - sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input.
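+
+The boundary pattern is always the same: convert at the boundary, call the
+NumPy-only routine, convert back. A condensed sketch of the DARE boundary in
+``model.refit_kalman._compute_gain`` (here ``A``, ``H``, ``W``, ``Q`` stand
+for the model matrices and ``xp``/``dev`` for the source namespace and
+device):
+
+.. code-block:: python
+
+   import numpy as np
+   from scipy.linalg import solve_discrete_are
+   from ezmsg.sigproc.util.array import xp_asarray
+
+   # Convert to NumPy for the DARE solver (no Array API equivalent)
+   A_np = np.asarray(A)
+   H_np = np.asarray(H)
+   P_np = solve_discrete_are(A_np.T, H_np.T, np.asarray(W), np.asarray(Q))
+
+   # Convert the result back to the source namespace and device
+   P = xp_asarray(xp, P_np, device=dev)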
+
+Not converted
+^^^^^^^^^^^^^
+
+These modules use NumPy directly. Conversion would provide little benefit
+because the underlying estimator is the bottleneck.
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 75
+
+   * - Module
+     - Reason
+   * - ``process.linear_regressor``
+     - Thin wrapper around sklearn ``LinearModel.predict``.
+       Could be made compatible if sklearn's ``array_api_dispatch`` is enabled
+       (see below).
+   * - ``process.sgd``
+     - sklearn ``SGDClassifier`` has no Array API support.
+   * - ``process.sklearn``
+     - Generic wrapper for arbitrary models; cannot assume Array API support.
+   * - ``dim_reduce.incremental_decomp``
+     - Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on
+       Python tuples).
+
+
+sklearn Array API Dispatch
+--------------------------
+
+scikit-learn 1.8+ has experimental support for Array API dispatch on a subset
+of estimators. Two estimators used in ezmsg-learn are on the supported list:
+
+.. list-table::
+   :header-rows: 1
+   :widths: 30 30 40
+
+   * - Estimator
+     - Used in
+     - Constraint
+   * - ``LinearDiscriminantAnalysis``
+     - ``process.slda``
+     - Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage``
+       is not supported)
+   * - ``Ridge``
+     - ``process.linear_regressor``
+     - Requires ``solver="svd"``
+
+To use dispatch, enable it before creating the estimator:
+
+.. code-block:: python
+
+   from sklearn import set_config
+
+   set_config(array_api_dispatch=True)
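+
+A minimal end-to-end sketch, assuming PyTorch is installed and a
+scikit-learn version whose Array API dispatch covers
+``LinearDiscriminantAnalysis``:
+
+.. code-block:: python
+
+   import torch
+   from sklearn import set_config
+   from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+   set_config(array_api_dispatch=True)
+
+   X = torch.randn(100, 8)
+   y = torch.randint(0, 2, (100,))
+
+   lda = LinearDiscriminantAnalysis(solver="svd")  # "svd" required for dispatch
+   lda.fit(X, y)
+   proba = lda.predict_proba(X)  # remains a torch.Tensor under dispatch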
+
+.. warning::
+
+   - ``array_api_dispatch`` is marked **experimental** in sklearn.
+   - Solver constraints (``solver="svd"``) may produce slightly different
+     numerical results compared to other solvers.
+   - Enabling dispatch globally may affect other sklearn estimators in the
+     same process.
+   - ezmsg-learn does **not** enable dispatch by default.
+
+Estimators that do **not** support Array API dispatch:
+
+- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported
+- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor``
+- All river models
+
+
+Writing Array API Compatible Code
+----------------------------------
+
+When adding or modifying processors in ezmsg-learn, follow these patterns.
+
+Deriving the namespace
+^^^^^^^^^^^^^^^^^^^^^^
+
+Always derive ``xp`` from the input data, not from a hardcoded ``numpy``:
+
+.. code-block:: python
+
+   from array_api_compat import get_namespace
+   from ezmsg.sigproc.util.array import array_device, xp_create
+
+   def _process(self, message):
+       xp = get_namespace(message.data)
+       dev = array_device(message.data)
+       # Allocate new arrays in the same namespace, on the same device
+       out = xp_create(xp.zeros, message.data.shape, device=dev)
+
+Transposing matrices
+^^^^^^^^^^^^^^^^^^^^
+
+The Array API defines ``.T`` only for 2-D arrays. For portable code, use
+``xp.linalg.matrix_transpose()``:
+
+.. code-block:: python
+
+   # Before (numpy-only)
+   result = A.T @ B
+
+   # After (Array API)
+   _mT = xp.linalg.matrix_transpose
+   result = _mT(A) @ B
+
+Creating arrays
+^^^^^^^^^^^^^^^
+
+Use ``xp_create`` to handle device placement portably:
+
+.. code-block:: python
+
+   # Before
+   I = np.eye(n)
+   z = np.zeros((m, n), dtype=np.float64)
+
+   # After
+   I = xp_create(xp.eye, n, device=dev)
+   z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev)
+
+Handling sklearn boundaries
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When calling into sklearn (or other NumPy-only libraries), convert at the
+boundary and convert back:
+
+.. code-block:: python
+
+   import numpy as np
+   from array_api_compat import get_namespace, is_numpy_array
+
+   xp = get_namespace(X)
+
+   # Convert to numpy for sklearn
+   X_np = np.asarray(X) if not is_numpy_array(X) else X
+   result_np = estimator.predict(X_np)
+
+   # Convert back to the source namespace
+   result = xp.asarray(result_np) if not is_numpy_array(X) else result_np
+
+Checking for NaN
+^^^^^^^^^^^^^^^^
+
+Use ``xp.isnan`` instead of ``np.isnan``:
+
+.. code-block:: python
+
+   if xp.any(xp.isnan(message.data)):
+       return
+
+Norms
+^^^^^
+
+Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of
+``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``.
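+
+A minimal sketch, assuming ``A`` and ``B`` are matrices and ``v`` is a vector
+in namespace ``xp``:
+
+.. code-block:: python
+
+   # Before (numpy-only): np.linalg.norm(A - B) and np.linalg.norm(v)
+   err = xp.linalg.matrix_norm(A - B)  # Frobenius norm by default
+   mag = xp.linalg.vector_norm(v)      # Euclidean norm by default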
diff --git a/docs/source/index.rst b/docs/source/index.rst
index db90cf2..e240d38 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org
:caption: Contents:
guides/classification
+ guides/array_api
api/index
diff --git a/pyproject.toml b/pyproject.toml
index b175a1a..315ec87 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -9,8 +9,9 @@ license = "MIT"
requires-python = ">=3.10.15"
dynamic = ["version"]
dependencies = [
- "ezmsg-baseproc>=1.4.0",
- "ezmsg-sigproc>=2.15.0",
+ "ezmsg>=3.7.3",
+ "ezmsg-baseproc>=1.5.1",
+ "ezmsg-sigproc>=2.17.0",
"river>=0.22.0",
"scikit-learn>=1.6.0",
"torch>=2.6.0",
@@ -73,5 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"]
[tool.uv.sources]
# Uncomment to use development version of ezmsg from git
-#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
-#ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true }
\ No newline at end of file
+#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" }
\ No newline at end of file
diff --git a/src/ezmsg/learn/dim_reduce/adaptive_decomp.py b/src/ezmsg/learn/dim_reduce/adaptive_decomp.py
index 39bfac9..3803365 100644
--- a/src/ezmsg/learn/dim_reduce/adaptive_decomp.py
+++ b/src/ezmsg/learn/dim_reduce/adaptive_decomp.py
@@ -1,7 +1,18 @@
+"""Adaptive decomposition transformers (PCA, NMF).
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. Reshaping and output allocation
+ use Array API operations; a NumPy boundary is applied before sklearn
+ ``partial_fit``/``transform`` calls.
+"""
+
+import math
import typing
import ezmsg.core as ez
import numpy as np
+from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
@@ -128,6 +139,8 @@ def _process(self, message: AxisArray) -> AxisArray:
if in_dat.shape[ax_idx] == 0:
return self._state.template
+ xp = get_namespace(in_dat)
+
# Re-order axes
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
if message.dims != sorted_dims_exp:
@@ -137,16 +150,20 @@ def _process(self, message: AxisArray) -> AxisArray:
pass
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
- d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
- in_dat = in_dat.reshape((-1, d2))
+ d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+ in_dat = xp.reshape(in_dat, (-1, d2))
replace_kwargs = {
"axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]},
}
- # Transform data
+ # Transform data — sklearn needs numpy
if hasattr(self._state.estimator, "components_"):
- decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:])
+ in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+ decomp_dat = self._state.estimator.transform(in_np)
+ # Convert back to source namespace
+ decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat
+ decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:])
replace_kwargs["data"] = decomp_dat
return replace(self._state.template, **replace_kwargs)
@@ -165,6 +182,8 @@ def partial_fit(self, message: AxisArray) -> None:
if in_dat.shape[ax_idx] == 0:
return
+ xp = get_namespace(in_dat)
+
# Re-order axes if needed
sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes
if message.dims != sorted_dims_exp:
@@ -172,11 +191,12 @@ def partial_fit(self, message: AxisArray) -> None:
pass
# fold [iter_axis] + off_targ_axes together and fold targ_axes together
- d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :])
- in_dat = in_dat.reshape((-1, d2))
+ d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :])
+ in_dat = xp.reshape(in_dat, (-1, d2))
- # Fit the estimator
- self._state.estimator.partial_fit(in_dat)
+ # Fit the estimator — sklearn needs numpy
+ in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat
+ self._state.estimator.partial_fit(in_np)
class IncrementalPCASettings(AdaptiveDecompSettings):
diff --git a/src/ezmsg/learn/model/cca.py b/src/ezmsg/learn/model/cca.py
index 6a202b2..7ed98af 100644
--- a/src/ezmsg/learn/model/cca.py
+++ b/src/ezmsg/learn/model/cca.py
@@ -1,5 +1,29 @@
+"""Incremental Canonical Correlation Analysis (CCA).
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. All linear algebra uses Array API
+ operations; ``scipy.linalg.sqrtm`` is replaced by an eigendecomposition-
+ based inverse square root (:func:`_inv_sqrtm_spd`).
+"""
+
import numpy as np
-from scipy import linalg
+from array_api_compat import get_namespace
+from ezmsg.sigproc.util.array import array_device, xp_create
+
+
+def _inv_sqrtm_spd(xp, A):
+ """Inverse matrix square root for symmetric positive-definite matrices.
+
+ Computes ``inv(sqrtm(A)) = Q @ diag(1/sqrt(lambda)) @ Q^T`` using the
+ eigendecomposition. This is more numerically stable than computing
+ ``inv(sqrtm(...))`` separately and uses only Array API operations.
+ """
+ eigenvalues, eigenvectors = xp.linalg.eigh(A)
+ eigenvalues = xp.clip(eigenvalues, 1e-12, None) # avoid div-by-zero
+ inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues)
+ # Q @ diag(v) == Q * v (broadcasting), then @ Q^T
+ return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors)
class IncrementalCCA:
@@ -33,39 +57,52 @@ def __init__(
self.adaptation_rate = adaptation_rate
self.initialized = False
- def initialize(self, d1, d2):
- """Initialize the necessary matrices"""
+ def initialize(self, d1, d2, *, ref_array=None):
+ """Initialize the necessary matrices.
+
+ Args:
+ d1: Dimensionality of the first dataset.
+ d2: Dimensionality of the second dataset.
+ ref_array: Optional reference array to derive array namespace
+ and device from. If ``None``, defaults to NumPy.
+ """
self.d1 = d1
self.d2 = d2
+ if ref_array is not None:
+ xp = get_namespace(ref_array)
+ dev = array_device(ref_array)
+ else:
+ xp, dev = np, None
+
# Initialize correlation matrices
- self.C11 = np.zeros((d1, d1))
- self.C22 = np.zeros((d2, d2))
- self.C12 = np.zeros((d1, d2))
+ self.C11 = xp_create(xp.zeros, (d1, d1), dtype=xp.float64, device=dev)
+ self.C22 = xp_create(xp.zeros, (d2, d2), dtype=xp.float64, device=dev)
+ self.C12 = xp_create(xp.zeros, (d1, d2), dtype=xp.float64, device=dev)
self.initialized = True
def _compute_change_magnitude(self, C11_new, C22_new, C12_new):
- """Compute magnitude of change in correlation structure"""
+ """Compute magnitude of change in correlation structure."""
+ xp = get_namespace(self.C11)
+
# Frobenius norm of differences
- diff11 = np.linalg.norm(C11_new - self.C11)
- diff22 = np.linalg.norm(C22_new - self.C22)
- diff12 = np.linalg.norm(C12_new - self.C12)
+ diff11 = xp.linalg.matrix_norm(C11_new - self.C11)
+ diff22 = xp.linalg.matrix_norm(C22_new - self.C22)
+ diff12 = xp.linalg.matrix_norm(C12_new - self.C12)
# Normalize by matrix sizes
- diff11 /= self.d1 * self.d1
- diff22 /= self.d2 * self.d2
- diff12 /= self.d1 * self.d2
+ diff11 = diff11 / (self.d1 * self.d1)
+ diff22 = diff22 / (self.d2 * self.d2)
+ diff12 = diff12 / (self.d1 * self.d2)
- return (diff11 + diff22 + diff12) / 3
+ return float((diff11 + diff22 + diff12) / 3)
def _adapt_smoothing(self, change_magnitude):
- """Adapt smoothing factor based on detected changes"""
+ """Adapt smoothing factor based on detected changes."""
# If change is large, decrease smoothing factor
target_smoothing = self.base_smoothing * (1.0 - change_magnitude)
- target_smoothing = np.clip(
- target_smoothing, self.min_smoothing, self.max_smoothing
- )
+ target_smoothing = max(self.min_smoothing, min(target_smoothing, self.max_smoothing))
# Smooth the adaptation itself
self.current_smoothing = (
@@ -73,18 +110,21 @@ def _adapt_smoothing(self, change_magnitude):
) * self.current_smoothing + self.adaptation_rate * target_smoothing
def partial_fit(self, X1, X2, update_projections=True):
- """Update the model with new samples using adaptive smoothing
- Assumes X1 and X2 are already centered and scaled"""
+ """Update the model with new samples using adaptive smoothing.
+ Assumes X1 and X2 are already centered and scaled."""
+ xp = get_namespace(X1, X2)
+ _mT = xp.linalg.matrix_transpose
+
if not self.initialized:
- self.initialize(X1.shape[1], X2.shape[1])
+ self.initialize(X1.shape[1], X2.shape[1], ref_array=X1)
# Compute new correlation matrices from current batch
- C11_new = X1.T @ X1 / X1.shape[0]
- C22_new = X2.T @ X2 / X2.shape[0]
- C12_new = X1.T @ X2 / X1.shape[0]
+ C11_new = _mT(X1) @ X1 / X1.shape[0]
+ C22_new = _mT(X2) @ X2 / X2.shape[0]
+ C12_new = _mT(X1) @ X2 / X1.shape[0]
# Detect changes and adapt smoothing factor
- if self.C11.any(): # Skip first update
+ if bool(xp.any(self.C11 != 0)): # Skip first update
change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new)
self._adapt_smoothing(change_magnitude)
@@ -98,25 +138,26 @@ def partial_fit(self, X1, X2, update_projections=True):
self._update_projections()
def _update_projections(self):
- """Update canonical vectors and correlations"""
+ """Update canonical vectors and correlations."""
+ xp = get_namespace(self.C11)
+ dev = array_device(self.C11)
+ _mT = xp.linalg.matrix_transpose
+
eps = 1e-8
- C11_reg = self.C11 + eps * np.eye(self.d1)
- C22_reg = self.C22 + eps * np.eye(self.d2)
+ C11_reg = self.C11 + eps * xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev)
+ C22_reg = self.C22 + eps * xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev)
+
+ inv_sqrt_C11 = _inv_sqrtm_spd(xp, C11_reg)
+ inv_sqrt_C22 = _inv_sqrtm_spd(xp, C22_reg)
- K = (
- linalg.inv(linalg.sqrtm(C11_reg))
- @ self.C12
- @ linalg.inv(linalg.sqrtm(C22_reg))
- )
- U, self.correlations_, V = linalg.svd(K)
+ K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22
+ U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False)
- self.x_weights_ = linalg.inv(linalg.sqrtm(C11_reg)) @ U[:, : self.n_components]
- self.y_weights_ = (
- linalg.inv(linalg.sqrtm(C22_reg)) @ V.T[:, : self.n_components]
- )
+ self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components]
+ self.y_weights_ = inv_sqrt_C22 @ _mT(Vh)[:, : self.n_components]
def transform(self, X1, X2):
- """Project data onto canonical components"""
+ """Project data onto canonical components."""
X1_proj = X1 @ self.x_weights_
X2_proj = X2 @ self.y_weights_
return X1_proj, X2_proj
diff --git a/src/ezmsg/learn/model/refit_kalman.py b/src/ezmsg/learn/model/refit_kalman.py
index 4481c17..3d7124d 100644
--- a/src/ezmsg/learn/model/refit_kalman.py
+++ b/src/ezmsg/learn/model/refit_kalman.py
@@ -1,6 +1,17 @@
# refit_kalman.py
+"""Refit Kalman filter for adaptive neural decoding.
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. All linear algebra in :meth:`fit`,
+ :meth:`predict`, and :meth:`update` stays in the source array namespace.
+ The DARE solver in :meth:`_compute_gain` and the per-sample mutation loop
+ in :meth:`refit` use NumPy regardless of input backend.
+"""
import numpy as np
+from array_api_compat import get_namespace
+from ezmsg.sigproc.util.array import array_device, xp_asarray, xp_create
from numpy.linalg import LinAlgError
from scipy.linalg import solve_discrete_are
@@ -108,19 +119,22 @@ def fit(self, X_train, y_train):
"""
# self._validate_state_vector(y_train)
- X = np.array(y_train)
- Z = np.array(X_train)
+ xp = get_namespace(X_train, y_train)
+ _mT = xp.linalg.matrix_transpose
+
+ X = xp.asarray(y_train)
+ Z = xp.asarray(X_train)
n_samples = X.shape[0]
# Calculate the transition matrix (from x_t to x_t+1) using least-squares
X2 = X[1:, :] # x_{t+1}
X1 = X[:-1, :] # x_t
- A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1) # Transition matrix
- W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1) # Covariance of transition matrix
+ A = _mT(X2) @ X1 @ xp.linalg.inv(_mT(X1) @ X1) # Transition matrix
+ W = _mT(X2 - X1 @ _mT(A)) @ (X2 - X1 @ _mT(A)) / (n_samples - 1) # Covariance of transition matrix
# Calculate the measurement matrix (from x_t to z_t) using least-squares
- H = Z.T @ X @ np.linalg.inv(X.T @ X) # Measurement matrix
- Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0] # Covariance of measurement matrix
+ H = _mT(Z) @ X @ xp.linalg.inv(_mT(X) @ X) # Measurement matrix
+ Q = _mT(Z - X @ _mT(H)) @ (Z - X @ _mT(H)) / Z.shape[0] # Covariance of measurement matrix
self.A_state_transition_matrix = A
self.W_process_noise_covariance = W * self.process_noise_scale
@@ -132,12 +146,12 @@ def fit(self, X_train, y_train):
def refit(
self,
- X_neural: np.ndarray,
- Y_state: np.ndarray,
+ X_neural,
+ Y_state,
intention_velocity_indices: int | None = None,
- target_positions: np.ndarray | None = None,
- cursor_positions: np.ndarray | None = None,
- hold_indices: np.ndarray | None = None,
+ target_positions=None,
+ cursor_positions=None,
+ hold_indices=None,
):
"""
Refit the observation model based on new data.
@@ -179,13 +193,17 @@ def refit(
else:
vel_idx = intention_velocity_indices
+ # The per-sample mutation loop uses numpy for element-wise operations
+ # on small vectors (np.linalg.norm on 2-element vectors, scalar indexing).
+ Y_state_np = np.asarray(Y_state)
+ target_positions_np = np.asarray(target_positions) if target_positions is not None else None
+ cursor_positions_np = np.asarray(cursor_positions) if cursor_positions is not None else None
+
# Only remap velocity if target and cursor positions are provided
- if target_positions is None or cursor_positions is None:
- intended_states = Y_state.copy()
- else:
- intended_states = Y_state.copy()
+ intended_states = Y_state_np.copy()
+ if target_positions_np is not None and cursor_positions_np is not None:
# Calculate intended velocities for each sample
- for i, (state, pos, target) in enumerate(zip(Y_state, cursor_positions, target_positions)):
+ for i, (state, pos, target) in enumerate(zip(Y_state_np, cursor_positions_np, target_positions_np)):
is_hold = hold_indices[i] if hold_indices is not None else False
if is_hold:
@@ -213,14 +231,19 @@ def refit(
else:
intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2]
- intended_states = np.array(intended_states)
- Z = np.array(X_neural)
+ # Convert back to source namespace for final linalg
+ xp = get_namespace(X_neural)
+ dev = array_device(X_neural)
+ _mT = xp.linalg.matrix_transpose
+
+ intended_states = xp_asarray(xp, intended_states, device=dev)
+ Z = xp.asarray(X_neural)
# Recalculate observation matrix and noise covariance
H = (
- Z.T @ intended_states @ np.linalg.pinv(intended_states.T @ intended_states)
+ _mT(Z) @ intended_states @ xp.linalg.pinv(_mT(intended_states) @ intended_states)
) # Using pinv() instead of inv() to avoid singular matrix errors
- Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0]
+ Q = _mT(Z - intended_states @ _mT(H)) @ (Z - intended_states @ _mT(H)) / Z.shape[0]
self.H_observation_matrix = H
self.Q_measurement_noise_covariance = Q
@@ -236,56 +259,44 @@ def _compute_gain(self):
Riccati equation to find the optimal steady-state gain. In non-steady-state
mode, it computes the gain using the current covariance matrix.
+ The DARE solver requires NumPy arrays; results are converted back to the
+ source array namespace.
+
Raises:
LinAlgError: If the Riccati equation cannot be solved or matrix operations fail.
"""
- # TODO: consider removing non-steady-state for compute_gain() -
- # non_steady_state updates will occur during predict() and update()
- # if self.steady_state:
+ xp = get_namespace(self.A_state_transition_matrix)
+ dev = array_device(self.A_state_transition_matrix)
+ _mT = xp.linalg.matrix_transpose
+
+ # Convert to numpy for DARE (no Array API equivalent)
+ A_np = np.asarray(self.A_state_transition_matrix)
+ H_np = np.asarray(self.H_observation_matrix)
+ W_np = np.asarray(self.W_process_noise_covariance)
+ Q_np = np.asarray(self.Q_measurement_noise_covariance)
+
try:
- # Try with original matrices
- self.P_state_covariance = solve_discrete_are(
- self.A_state_transition_matrix.T,
- self.H_observation_matrix.T,
- self.W_process_noise_covariance,
- self.Q_measurement_noise_covariance,
- )
- self.K_kalman_gain = (
- self.P_state_covariance
- @ self.H_observation_matrix.T
- @ np.linalg.inv(
- self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T
- + self.Q_measurement_noise_covariance
- )
+ P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_np)
+ self.P_state_covariance = xp_asarray(xp, P_np, device=dev)
+ S = (
+ self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix)
+ + self.Q_measurement_noise_covariance
)
+ self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S)
except LinAlgError:
- # Apply regularization and retry
- # A_reg = self.A_state_transition_matrix * 0.999 # Slight damping
- # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye(
- # self.W_process_noise_covariance.shape[0]
- # )
- Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(self.Q_measurement_noise_covariance.shape[0])
+ Q_reg_np = Q_np + 1e-7 * np.eye(Q_np.shape[0])
try:
- self.P_state_covariance = solve_discrete_are(
- self.A_state_transition_matrix.T,
- self.H_observation_matrix.T,
- self.W_process_noise_covariance,
- Q_reg,
- )
- self.K_kalman_gain = (
- self.P_state_covariance
- @ self.H_observation_matrix.T
- @ np.linalg.inv(
- self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T + Q_reg
- )
- )
+ P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_reg_np)
+ self.P_state_covariance = xp_asarray(xp, P_np, device=dev)
+ Q_reg = xp_asarray(xp, Q_reg_np, device=dev)
+ S = self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix) + Q_reg
+ self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S)
print("Warning: Used regularized matrices for DARE solution")
except LinAlgError:
# Fallback to identity or manual initialization
print("Warning: DARE failed, using identity covariance")
- self.P_state_covariance = np.eye(self.A_state_transition_matrix.shape[0])
-
+ self.P_state_covariance = xp_create(xp.eye, self.A_state_transition_matrix.shape[0], device=dev)
# else:
# n_states = self.A_state_transition_matrix.shape[0]
# self.P_state_covariance = (
@@ -311,29 +322,35 @@ def _compute_gain(self):
# I_mat - self.K_kalman_gain @ self.H_observation_matrix
# ) @ P_m
- def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+ def predict(self, x_current):
"""
Predict the next state and covariance.
This method predicts the next state and covariance using the current state.
"""
+ xp = get_namespace(x_current)
+ _mT = xp.linalg.matrix_transpose
+
x_predicted = self.A_state_transition_matrix @ x_current
if self.steady_state is True:
return x_predicted, None
else:
P_predicted = self.alpha_fading_memory**2 * (
- self.A_state_transition_matrix @ self.P_state_covariance @ self.A_state_transition_matrix.T
+ self.A_state_transition_matrix @ self.P_state_covariance @ _mT(self.A_state_transition_matrix)
+ self.W_process_noise_covariance
)
return x_predicted, P_predicted
def update(
self,
- z_measurement: np.ndarray,
- x_predicted: np.ndarray,
- P_predicted: np.ndarray | None = None,
- ) -> np.ndarray:
+ z_measurement,
+ x_predicted,
+ P_predicted=None,
+ ):
"""Update state estimate and covariance based on measurement z."""
+ xp = get_namespace(z_measurement, x_predicted)
+ dev = array_device(x_predicted)
+ _mT = xp.linalg.matrix_transpose
# Compute residual
innovation = z_measurement - self.H_observation_matrix @ x_predicted
@@ -347,19 +364,23 @@ def update(
# Non-steady-state mode
# System uncertainty
- S = self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + self.Q_measurement_noise_covariance
+ S = (
+ self.H_observation_matrix @ P_predicted @ _mT(self.H_observation_matrix)
+ + self.Q_measurement_noise_covariance
+ )
# Kalman gain
- K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S)
+ K = P_predicted @ _mT(self.H_observation_matrix) @ xp.linalg.pinv(S)
# Updated state
x_updated = x_predicted + K @ innovation
# Covariance update
- I_mat = np.eye(self.A_state_transition_matrix.shape[0])
- P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ (
+ n = self.A_state_transition_matrix.shape[0]
+ I_mat = xp_create(xp.eye, n, device=dev)
+ P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ _mT(
I_mat - K @ self.H_observation_matrix
- ).T + K @ self.Q_measurement_noise_covariance @ K.T
+ ) + K @ self.Q_measurement_noise_covariance @ _mT(K)
# Save updated values
self.P_state_covariance = P_updated
diff --git a/src/ezmsg/learn/process/adaptive_linear_regressor.py b/src/ezmsg/learn/process/adaptive_linear_regressor.py
index 3e00909..6ca65b1 100644
--- a/src/ezmsg/learn/process/adaptive_linear_regressor.py
+++ b/src/ezmsg/learn/process/adaptive_linear_regressor.py
@@ -1,3 +1,12 @@
+"""Adaptive linear regressor processor.
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. NaN checks and axis permutations
+ use Array API operations; a NumPy boundary is applied before sklearn
+ ``partial_fit``/``predict`` and before river ``learn_many``/``predict_many``.
+"""
+
from dataclasses import field
import ezmsg.core as ez
@@ -6,6 +15,7 @@
import river.linear_model
import river.optim
import sklearn.base
+from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
@@ -78,24 +88,32 @@ def _reset_state(self, message: AxisArray) -> None:
pass
def partial_fit(self, message: AxisArray) -> None:
- if np.any(np.isnan(message.data)):
+ xp = get_namespace(message.data)
+
+ if xp.any(xp.isnan(message.data)):
return
if self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]:
- x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
+ # river path: needs numpy/pandas
+ data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data
+ x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, data_np.T)})
y = pd.Series(
data=message.attrs["trigger"].value.data[:, 0],
name=message.attrs["trigger"].value.axes["ch"].data[0],
)
self.state.model.learn_many(x, y)
else:
+ # sklearn path: permute then convert to numpy
X = message.data
- if message.get_axis_idx("time") != 0:
- X = np.moveaxis(X, message.get_axis_idx("time"), 0)
- self.state.model.partial_fit(X, message.attrs["trigger"].value.data)
+ ax_idx = message.get_axis_idx("time")
+ if ax_idx != 0:
+ perm = (ax_idx,) + tuple(i for i in range(X.ndim) if i != ax_idx)
+ X = xp.permute_dims(X, perm)
+ X_np = np.asarray(X) if not is_numpy_array(X) else X
+ self.state.model.partial_fit(X_np, message.attrs["trigger"].value.data)
self.state.template = replace(
message.attrs["trigger"].value,
@@ -107,16 +125,21 @@ def _process(self, message: AxisArray) -> AxisArray | None:
if self.state.template is None:
return AxisArray(np.array([]), dims=[""])
- if not np.any(np.isnan(message.data)):
+ xp = get_namespace(message.data)
+
+ if not xp.any(xp.isnan(message.data)):
if self.settings.model_type in [
AdaptiveLinearRegressor.LINEAR,
AdaptiveLinearRegressor.LOGISTIC,
]:
- # convert msg_in.data to something appropriate for river
- x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)})
+ # river path: needs numpy/pandas
+ data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data
+ x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, data_np.T)})
preds = self.state.model.predict_many(x).values
else:
- preds = self.state.model.predict(message.data)
+ # sklearn path: needs numpy
+ data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data
+ preds = self.state.model.predict(data_np)
return replace(
self.state.template,
data=preds.reshape((len(preds), -1)),
diff --git a/src/ezmsg/learn/process/refit_kalman.py b/src/ezmsg/learn/process/refit_kalman.py
index 6684bc9..62ae835 100644
--- a/src/ezmsg/learn/process/refit_kalman.py
+++ b/src/ezmsg/learn/process/refit_kalman.py
@@ -1,13 +1,24 @@
+"""Refit Kalman filter processor for ezmsg.
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. State initialization and output
+ arrays use the source data's namespace. Buffered data for refitting
+ is stacked via ``xp.stack``.
+"""
+
import pickle
from pathlib import Path
import ezmsg.core as ez
import numpy as np
+from array_api_compat import get_namespace
from ezmsg.baseproc import (
BaseAdaptiveTransformer,
BaseAdaptiveTransformerUnit,
processor_state,
)
+from ezmsg.sigproc.util.array import array_device, xp_create
from ezmsg.util.messages.axisarray import AxisArray
from ezmsg.util.messages.util import replace
@@ -54,8 +65,8 @@ class RefitKalmanFilterState:
"""
model: RefitKalmanFilter | None = None
- x: np.ndarray | None = None
- P: np.ndarray | None = None
+ x: object | None = None # Array API; namespace matches source data.
+ P: object | None = None # Array API; namespace matches source data.
buffer_neural: list | None = None
buffer_state: list | None = None
@@ -142,7 +153,7 @@ def _init_model(self, **kwargs):
config.update(kwargs)
self._state.model = RefitKalmanFilter(**config)
- def fit(self, X: np.ndarray, y: np.ndarray) -> None:
+ def fit(self, X, y) -> None:
if self._state.model is None:
self._init_model()
if hasattr(self._state.model, "fit"):
@@ -176,6 +187,11 @@ def save_checkpoint(self, checkpoint_path: str) -> None:
Raises:
ValueError: If the model is not initialized or has not been fitted.
+
+ .. note::
+ Checkpoint data is pickled as NumPy arrays. If model matrices live
+ on a non-NumPy backend (e.g. CuPy), pickle may fail — this is a
+ pre-existing limitation.
"""
if not self._state.model or not self._state.model.is_fitted:
raise ValueError("Cannot save checkpoint: model not fitted")
@@ -222,8 +238,15 @@ def _reset_state(
if self._state.model.A_state_transition_matrix is not None:
state_dim = self._state.model.A_state_transition_matrix.shape[0]
- self._state.x = np.zeros(state_dim)
- self._state.P = np.eye(state_dim)
+ # Derive xp/dev from message data when available; default to numpy.
+ if message is not None:
+ xp = get_namespace(message.data)
+ dev = array_device(message.data)
+ else:
+ xp, dev = np, None
+
+ self._state.x = xp_create(xp.zeros, (state_dim,), dtype=xp.float64, device=dev)
+ self._state.P = xp_create(xp.eye, state_dim, dtype=xp.float64, device=dev)
self._state.buffer_neural = []
self._state.buffer_state = []
@@ -252,15 +275,19 @@ def _process(self, message: AxisArray) -> AxisArray:
# No checkpoint means you need to initialize and fit the model
elif not self._state.model:
self._init_model()
+
+ xp = get_namespace(message.data)
+ dev = array_device(message.data)
+
state_dim = self._state.model.A_state_transition_matrix.shape[0]
if self._state.x is None:
- self._state.x = np.zeros(state_dim)
+ self._state.x = xp_create(xp.zeros, (state_dim,), dtype=xp.float64, device=dev)
- filtered_data = np.zeros(
- (
- message.data.shape[0],
- self._state.model.A_state_transition_matrix.shape[0],
- )
+ filtered_data = xp_create(
+ xp.zeros,
+ (message.data.shape[0], state_dim),
+ dtype=message.data.dtype,
+ device=dev,
)
for i in range(message.data.shape[0]):
@@ -271,9 +298,9 @@ def _process(self, message: AxisArray) -> AxisArray:
# Update
x_updated = self._state.model.update(measurement, x_pred, P_pred)
- # Store
- self._state.x = x_updated.copy()
- self._state.P = self._state.model.P_state_covariance.copy()
+ # Store — no .copy() needed, predict/update return new arrays
+ self._state.x = x_updated
+ self._state.P = self._state.model.P_state_covariance
filtered_data[i] = self._state.x
return replace(
@@ -297,7 +324,8 @@ def partial_fit(self, message: AxisArray) -> None:
if "trigger" not in message.attrs:
raise ValueError("Invalid message format for partial_fit.")
- X = np.array(message.data)
+ xp = get_namespace(message.data)
+ X = xp.asarray(message.data)
values = message.attrs["trigger"].value
if not isinstance(values, dict) or "Y_state" not in values:
@@ -305,7 +333,7 @@ def partial_fit(self, message: AxisArray) -> None:
kwargs = {
"X_neural": X,
- "Y_state": np.array(values["Y_state"]),
+ "Y_state": xp.asarray(values["Y_state"]),
}
# Optional fields
@@ -316,7 +344,7 @@ def partial_fit(self, message: AxisArray) -> None:
"hold_flags",
]:
if key in values and values[key] is not None:
- kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(values[key])
+ kwargs[key if key != "hold_flags" else "hold_indices"] = xp.asarray(values[key])
# Call model refit
self._state.model.refit(**kwargs)
@@ -324,7 +352,7 @@ def partial_fit(self, message: AxisArray) -> None:
def log_for_refit(
self,
message: AxisArray,
- target_position: np.ndarray | None = None,
+ target_position=None,
hold_flag: bool | None = None,
):
"""
@@ -341,13 +369,13 @@ def log_for_refit(
hold_flag: Boolean flag indicating if this is a hold period.
"""
if target_position is not None:
- self._state.buffer_target_positions.append(target_position.copy())
+ self._state.buffer_target_positions.append(target_position)
if hold_flag is not None:
self._state.buffer_hold_flags.append(hold_flag)
measurement = message.data[-1]
- self._state.buffer_neural.append(measurement.copy())
- self._state.buffer_state.append(self._state.x.copy())
+ self._state.buffer_neural.append(measurement)
+ self._state.buffer_state.append(self._state.x)
def refit_model(self):
"""
@@ -370,17 +398,19 @@ def refit_model(self):
print("No buffered data to refit")
return
+ xp = get_namespace(self._state.buffer_neural[0])
+
kwargs = {
- "X_neural": np.array(self._state.buffer_neural),
- "Y_state": np.array(self._state.buffer_state),
+ "X_neural": xp.stack(self._state.buffer_neural),
+ "Y_state": xp.stack(self._state.buffer_state),
"intention_velocity_indices": self.settings.velocity_indices[0],
}
if self._state.buffer_target_positions and self._state.buffer_cursor_positions:
- kwargs["target_positions"] = np.array(self._state.buffer_target_positions)
- kwargs["cursor_positions"] = np.array(self._state.buffer_cursor_positions)
+ kwargs["target_positions"] = xp.stack(self._state.buffer_target_positions)
+ kwargs["cursor_positions"] = xp.stack(self._state.buffer_cursor_positions)
if self._state.buffer_hold_flags:
- kwargs["hold_indices"] = np.array(self._state.buffer_hold_flags)
+ kwargs["hold_indices"] = xp.asarray(self._state.buffer_hold_flags)
self._state.model.refit(**kwargs)
diff --git a/src/ezmsg/learn/process/slda.py b/src/ezmsg/learn/process/slda.py
index f3825a0..91a24b9 100644
--- a/src/ezmsg/learn/process/slda.py
+++ b/src/ezmsg/learn/process/slda.py
@@ -1,7 +1,17 @@
+"""Shrinkage LDA classifier processor.
+
+.. note::
+ This module supports the Array API standard via
+ ``array_api_compat.get_namespace()``. Input data is manipulated using
+ Array API operations (``permute_dims``, ``reshape``); a NumPy boundary
+ is applied before ``sklearn.predict_proba``.
+"""
+
import typing
import ezmsg.core as ez
import numpy as np
+from array_api_compat import get_namespace, is_numpy_array
from ezmsg.baseproc import (
BaseStatefulTransformer,
BaseTransformerUnit,
@@ -72,23 +82,29 @@ def _reset_state(self, message: AxisArray) -> None:
)
def _process(self, message: AxisArray) -> ClassifierMessage:
+ xp = get_namespace(message.data)
samp_ax_idx = message.dims.index(self.settings.axis)
- X = np.moveaxis(message.data, samp_ax_idx, 0)
+
+ # Move sample axis to front
+ perm = (samp_ax_idx,) + tuple(i for i in range(message.data.ndim) if i != samp_ax_idx)
+ X = xp.permute_dims(message.data, perm)
if X.shape[0]:
if isinstance(self.settings.settings_path, str) and self.settings.settings_path[-4:] == ".mat":
- # Assumes F-contiguous weights
+ # Assumes F-contiguous weights — need numpy for predict_proba
+ X_np = np.asarray(X) if not is_numpy_array(X) else X
pred_probas = []
- for samp in X:
+ for samp in X_np:
tmp = samp.flatten(order="F") * 1e-6
tmp = np.expand_dims(tmp, axis=0)
probas = self.state.lda.predict_proba(tmp)
pred_probas.append(probas)
pred_probas = np.concatenate(pred_probas, axis=0)
else:
- # This creates a copy.
- X = X.reshape(X.shape[0], -1)
- pred_probas = self.state.lda.predict_proba(X)
+ # Numpy boundary before sklearn predict_proba
+ X_np = np.asarray(X) if not is_numpy_array(X) else X
+ X_np = X_np.reshape(X_np.shape[0], -1)
+ pred_probas = self.state.lda.predict_proba(X_np)
update_ax = self.state.out_template.axes[self.settings.axis]
update_ax.offset = message.axes[self.settings.axis].offset
diff --git a/tests/unit/test_cca.py b/tests/unit/test_cca.py
new file mode 100644
index 0000000..a4ffdb5
--- /dev/null
+++ b/tests/unit/test_cca.py
@@ -0,0 +1,125 @@
+import numpy as np
+import pytest
+
+from ezmsg.learn.model.cca import IncrementalCCA, _inv_sqrtm_spd
+
+
+@pytest.fixture
+def rng():
+ return np.random.default_rng(42)
+
+
+@pytest.fixture
+def synthetic_data(rng):
+ """Create correlated synthetic data for CCA testing."""
+ n_samples, d1, d2 = 100, 5, 4
+ # Shared latent factor
+ latent = rng.standard_normal((n_samples, 2))
+ X1 = latent @ rng.standard_normal((2, d1)) + 0.1 * rng.standard_normal((n_samples, d1))
+ X2 = latent @ rng.standard_normal((2, d2)) + 0.1 * rng.standard_normal((n_samples, d2))
+ return X1, X2
+
+
+def test_initialize():
+ """Test that initialize creates matrices with correct shapes."""
+ cca = IncrementalCCA(n_components=2)
+ cca.initialize(5, 4)
+
+ assert cca.C11.shape == (5, 5)
+ assert cca.C22.shape == (4, 4)
+ assert cca.C12.shape == (5, 4)
+ assert cca.initialized is True
+ assert cca.d1 == 5
+ assert cca.d2 == 4
+
+
+def test_initialize_with_ref_array():
+ """Test that initialize respects ref_array namespace."""
+ ref = np.zeros(3)
+ cca = IncrementalCCA(n_components=2)
+ cca.initialize(5, 4, ref_array=ref)
+
+ assert cca.C11.shape == (5, 5)
+ assert cca.C11.dtype == np.float64
+
+
+def test_partial_fit(synthetic_data):
+ """Test that partial_fit runs without error and updates covariance."""
+ X1, X2 = synthetic_data
+ cca = IncrementalCCA(n_components=2)
+ cca.partial_fit(X1, X2)
+
+ assert cca.initialized is True
+ assert not np.allclose(cca.C11, 0)
+ assert not np.allclose(cca.C22, 0)
+ assert not np.allclose(cca.C12, 0)
+ assert hasattr(cca, "x_weights_")
+ assert hasattr(cca, "y_weights_")
+
+
+def test_partial_fit_incremental(synthetic_data):
+ """Test that multiple partial_fit calls update smoothly."""
+ X1, X2 = synthetic_data
+ cca = IncrementalCCA(n_components=2)
+
+ # First fit
+ cca.partial_fit(X1[:50], X2[:50])
+ C11_first = cca.C11.copy()
+
+ # Second fit should change covariance
+ cca.partial_fit(X1[50:], X2[50:])
+ assert not np.allclose(C11_first, cca.C11)
+
+
+def test_transform(synthetic_data):
+ """Test that transform produces correct output shapes."""
+ X1, X2 = synthetic_data
+ cca = IncrementalCCA(n_components=2)
+ cca.partial_fit(X1, X2)
+
+ X1_proj, X2_proj = cca.transform(X1, X2)
+ assert X1_proj.shape == (100, 2)
+ assert X2_proj.shape == (100, 2)
+
+
+def test_transform_single_component(synthetic_data):
+ """Test with a single canonical component."""
+ X1, X2 = synthetic_data
+ cca = IncrementalCCA(n_components=1)
+ cca.partial_fit(X1, X2)
+
+ X1_proj, X2_proj = cca.transform(X1, X2)
+ assert X1_proj.shape == (100, 1)
+ assert X2_proj.shape == (100, 1)
+
+
+def test_numerical_equivalence_inv_sqrtm():
+ """Compare eigh-based _inv_sqrtm_spd with scipy.linalg on known SPD matrices."""
+ scipy_linalg = pytest.importorskip("scipy.linalg")
+
+ rng = np.random.default_rng(123)
+ # Create a known SPD matrix
+ A_raw = rng.standard_normal((5, 5))
+ A = A_raw.T @ A_raw + 0.1 * np.eye(5) # Ensure SPD
+
+ # scipy reference
+ sqrtm_A = scipy_linalg.sqrtm(A)
+ inv_sqrtm_scipy = scipy_linalg.inv(np.real(sqrtm_A))
+
+ # Our implementation
+ inv_sqrtm_eigh = _inv_sqrtm_spd(np, A)
+
+ np.testing.assert_allclose(inv_sqrtm_eigh, inv_sqrtm_scipy, atol=1e-10)
+
+
+def test_correlations_are_computed(synthetic_data):
+ """Test that canonical correlations are stored after fit."""
+ X1, X2 = synthetic_data
+ cca = IncrementalCCA(n_components=2)
+ cca.partial_fit(X1, X2)
+
+ assert hasattr(cca, "correlations_")
+ assert len(cca.correlations_) >= 2
+ # Correlations should be between 0 and 1 for well-conditioned data
+ assert np.all(cca.correlations_[:2] >= 0)
+ assert np.all(cca.correlations_[:2] <= 1.0 + 1e-6)