From 10ae3d3fcda09c15121d3985cfcee21371ec2733 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Sat, 7 Feb 2026 00:19:34 -0500 Subject: [PATCH 1/3] Make as much processing as possible use Array API --- src/ezmsg/learn/dim_reduce/adaptive_decomp.py | 36 +++- src/ezmsg/learn/model/cca.py | 119 ++++++++----- src/ezmsg/learn/model/refit_kalman.py | 159 ++++++++++-------- .../process/adaptive_linear_regressor.py | 41 ++++- src/ezmsg/learn/process/refit_kalman.py | 82 ++++++--- src/ezmsg/learn/process/slda.py | 28 ++- tests/unit/test_cca.py | 125 ++++++++++++++ 7 files changed, 433 insertions(+), 157 deletions(-) create mode 100644 tests/unit/test_cca.py diff --git a/src/ezmsg/learn/dim_reduce/adaptive_decomp.py b/src/ezmsg/learn/dim_reduce/adaptive_decomp.py index 39bfac9..3803365 100644 --- a/src/ezmsg/learn/dim_reduce/adaptive_decomp.py +++ b/src/ezmsg/learn/dim_reduce/adaptive_decomp.py @@ -1,7 +1,18 @@ +"""Adaptive decomposition transformers (PCA, NMF). + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. Reshaping and output allocation + use Array API operations; a NumPy boundary is applied before sklearn + ``partial_fit``/``transform`` calls. +""" + +import math import typing import ezmsg.core as ez import numpy as np +from array_api_compat import get_namespace, is_numpy_array from ezmsg.baseproc import ( BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit, @@ -128,6 +139,8 @@ def _process(self, message: AxisArray) -> AxisArray: if in_dat.shape[ax_idx] == 0: return self._state.template + xp = get_namespace(in_dat) + # Re-order axes sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes if message.dims != sorted_dims_exp: @@ -137,16 +150,20 @@ def _process(self, message: AxisArray) -> AxisArray: pass # fold [iter_axis] + off_targ_axes together and fold targ_axes together - d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :]) - in_dat = in_dat.reshape((-1, d2)) + d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :]) + in_dat = xp.reshape(in_dat, (-1, d2)) replace_kwargs = { "axes": {**self._state.template.axes, iter_axis: message.axes[iter_axis]}, } - # Transform data + # Transform data — sklearn needs numpy if hasattr(self._state.estimator, "components_"): - decomp_dat = self._state.estimator.transform(in_dat).reshape((-1,) + self._state.template.data.shape[1:]) + in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat + decomp_dat = self._state.estimator.transform(in_np) + # Convert back to source namespace + decomp_dat = xp.asarray(decomp_dat) if not is_numpy_array(in_dat) else decomp_dat + decomp_dat = xp.reshape(decomp_dat, (-1,) + self._state.template.data.shape[1:]) replace_kwargs["data"] = decomp_dat return replace(self._state.template, **replace_kwargs) @@ -165,6 +182,8 @@ def partial_fit(self, message: AxisArray) -> None: if in_dat.shape[ax_idx] == 0: return + xp = get_namespace(in_dat) + # Re-order axes if needed sorted_dims_exp = [iter_axis] + off_targ_axes + targ_axes if message.dims != sorted_dims_exp: @@ -172,11 +191,12 @@ def partial_fit(self, message: AxisArray) -> None: pass # fold [iter_axis] + off_targ_axes together and fold targ_axes together - d2 = np.prod(in_dat.shape[len(off_targ_axes) + 1 :]) - in_dat = in_dat.reshape((-1, d2)) + d2 = math.prod(in_dat.shape[len(off_targ_axes) + 1 :]) + in_dat = xp.reshape(in_dat, (-1, d2)) - # Fit the estimator - self._state.estimator.partial_fit(in_dat) + # Fit the estimator — sklearn needs numpy + in_np = np.asarray(in_dat) if not is_numpy_array(in_dat) else in_dat + self._state.estimator.partial_fit(in_np) class IncrementalPCASettings(AdaptiveDecompSettings): diff --git a/src/ezmsg/learn/model/cca.py b/src/ezmsg/learn/model/cca.py index 6a202b2..7ed98af 100644 --- a/src/ezmsg/learn/model/cca.py +++ b/src/ezmsg/learn/model/cca.py @@ -1,5 +1,29 @@ +"""Incremental Canonical Correlation Analysis (CCA). + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. All linear algebra uses Array API + operations; ``scipy.linalg.sqrtm`` is replaced by an eigendecomposition- + based inverse square root (:func:`_inv_sqrtm_spd`). +""" + import numpy as np -from scipy import linalg +from array_api_compat import get_namespace +from ezmsg.sigproc.util.array import array_device, xp_create + + +def _inv_sqrtm_spd(xp, A): + """Inverse matrix square root for symmetric positive-definite matrices. + + Computes ``inv(sqrtm(A)) = Q @ diag(1/sqrt(lambda)) @ Q^T`` using the + eigendecomposition. This is more numerically stable than computing + ``inv(sqrtm(...))`` separately and uses only Array API operations. + """ + eigenvalues, eigenvectors = xp.linalg.eigh(A) + eigenvalues = xp.clip(eigenvalues, 1e-12, None) # avoid div-by-zero + inv_sqrt_eig = 1.0 / xp.sqrt(eigenvalues) + # Q @ diag(v) == Q * v (broadcasting), then @ Q^T + return (eigenvectors * inv_sqrt_eig) @ xp.linalg.matrix_transpose(eigenvectors) class IncrementalCCA: @@ -33,39 +57,52 @@ def __init__( self.adaptation_rate = adaptation_rate self.initialized = False - def initialize(self, d1, d2): - """Initialize the necessary matrices""" + def initialize(self, d1, d2, *, ref_array=None): + """Initialize the necessary matrices. + + Args: + d1: Dimensionality of the first dataset. + d2: Dimensionality of the second dataset. + ref_array: Optional reference array to derive array namespace + and device from. If ``None``, defaults to NumPy. + """ self.d1 = d1 self.d2 = d2 + if ref_array is not None: + xp = get_namespace(ref_array) + dev = array_device(ref_array) + else: + xp, dev = np, None + # Initialize correlation matrices - self.C11 = np.zeros((d1, d1)) - self.C22 = np.zeros((d2, d2)) - self.C12 = np.zeros((d1, d2)) + self.C11 = xp_create(xp.zeros, (d1, d1), dtype=xp.float64, device=dev) + self.C22 = xp_create(xp.zeros, (d2, d2), dtype=xp.float64, device=dev) + self.C12 = xp_create(xp.zeros, (d1, d2), dtype=xp.float64, device=dev) self.initialized = True def _compute_change_magnitude(self, C11_new, C22_new, C12_new): - """Compute magnitude of change in correlation structure""" + """Compute magnitude of change in correlation structure.""" + xp = get_namespace(self.C11) + # Frobenius norm of differences - diff11 = np.linalg.norm(C11_new - self.C11) - diff22 = np.linalg.norm(C22_new - self.C22) - diff12 = np.linalg.norm(C12_new - self.C12) + diff11 = xp.linalg.matrix_norm(C11_new - self.C11) + diff22 = xp.linalg.matrix_norm(C22_new - self.C22) + diff12 = xp.linalg.matrix_norm(C12_new - self.C12) # Normalize by matrix sizes - diff11 /= self.d1 * self.d1 - diff22 /= self.d2 * self.d2 - diff12 /= self.d1 * self.d2 + diff11 = diff11 / (self.d1 * self.d1) + diff22 = diff22 / (self.d2 * self.d2) + diff12 = diff12 / (self.d1 * self.d2) - return (diff11 + diff22 + diff12) / 3 + return float((diff11 + diff22 + diff12) / 3) def _adapt_smoothing(self, change_magnitude): - """Adapt smoothing factor based on detected changes""" + """Adapt smoothing factor based on detected changes.""" # If change is large, decrease smoothing factor target_smoothing = self.base_smoothing * (1.0 - change_magnitude) - target_smoothing = np.clip( - target_smoothing, self.min_smoothing, self.max_smoothing - ) + target_smoothing = max(self.min_smoothing, min(target_smoothing, self.max_smoothing)) # Smooth the adaptation itself self.current_smoothing = ( @@ -73,18 +110,21 @@ def _adapt_smoothing(self, change_magnitude): ) * self.current_smoothing + self.adaptation_rate * target_smoothing def partial_fit(self, X1, X2, update_projections=True): - """Update the model with new samples using adaptive smoothing - Assumes X1 and X2 are already centered and scaled""" + """Update the model with new samples using adaptive smoothing. + Assumes X1 and X2 are already centered and scaled.""" + xp = get_namespace(X1, X2) + _mT = xp.linalg.matrix_transpose + if not self.initialized: - self.initialize(X1.shape[1], X2.shape[1]) + self.initialize(X1.shape[1], X2.shape[1], ref_array=X1) # Compute new correlation matrices from current batch - C11_new = X1.T @ X1 / X1.shape[0] - C22_new = X2.T @ X2 / X2.shape[0] - C12_new = X1.T @ X2 / X1.shape[0] + C11_new = _mT(X1) @ X1 / X1.shape[0] + C22_new = _mT(X2) @ X2 / X2.shape[0] + C12_new = _mT(X1) @ X2 / X1.shape[0] # Detect changes and adapt smoothing factor - if self.C11.any(): # Skip first update + if bool(xp.any(self.C11 != 0)): # Skip first update change_magnitude = self._compute_change_magnitude(C11_new, C22_new, C12_new) self._adapt_smoothing(change_magnitude) @@ -98,25 +138,26 @@ def partial_fit(self, X1, X2, update_projections=True): self._update_projections() def _update_projections(self): - """Update canonical vectors and correlations""" + """Update canonical vectors and correlations.""" + xp = get_namespace(self.C11) + dev = array_device(self.C11) + _mT = xp.linalg.matrix_transpose + eps = 1e-8 - C11_reg = self.C11 + eps * np.eye(self.d1) - C22_reg = self.C22 + eps * np.eye(self.d2) + C11_reg = self.C11 + eps * xp_create(xp.eye, self.d1, dtype=self.C11.dtype, device=dev) + C22_reg = self.C22 + eps * xp_create(xp.eye, self.d2, dtype=self.C22.dtype, device=dev) + + inv_sqrt_C11 = _inv_sqrtm_spd(xp, C11_reg) + inv_sqrt_C22 = _inv_sqrtm_spd(xp, C22_reg) - K = ( - linalg.inv(linalg.sqrtm(C11_reg)) - @ self.C12 - @ linalg.inv(linalg.sqrtm(C22_reg)) - ) - U, self.correlations_, V = linalg.svd(K) + K = inv_sqrt_C11 @ self.C12 @ inv_sqrt_C22 + U, self.correlations_, Vh = xp.linalg.svd(K, full_matrices=False) - self.x_weights_ = linalg.inv(linalg.sqrtm(C11_reg)) @ U[:, : self.n_components] - self.y_weights_ = ( - linalg.inv(linalg.sqrtm(C22_reg)) @ V.T[:, : self.n_components] - ) + self.x_weights_ = inv_sqrt_C11 @ U[:, : self.n_components] + self.y_weights_ = inv_sqrt_C22 @ _mT(Vh)[:, : self.n_components] def transform(self, X1, X2): - """Project data onto canonical components""" + """Project data onto canonical components.""" X1_proj = X1 @ self.x_weights_ X2_proj = X2 @ self.y_weights_ return X1_proj, X2_proj diff --git a/src/ezmsg/learn/model/refit_kalman.py b/src/ezmsg/learn/model/refit_kalman.py index 4481c17..3d7124d 100644 --- a/src/ezmsg/learn/model/refit_kalman.py +++ b/src/ezmsg/learn/model/refit_kalman.py @@ -1,6 +1,17 @@ # refit_kalman.py +"""Refit Kalman filter for adaptive neural decoding. + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. All linear algebra in :meth:`fit`, + :meth:`predict`, and :meth:`update` stays in the source array namespace. + The DARE solver in :meth:`_compute_gain` and the per-sample mutation loop + in :meth:`refit` use NumPy regardless of input backend. +""" import numpy as np +from array_api_compat import get_namespace +from ezmsg.sigproc.util.array import array_device, xp_asarray, xp_create from numpy.linalg import LinAlgError from scipy.linalg import solve_discrete_are @@ -108,19 +119,22 @@ def fit(self, X_train, y_train): """ # self._validate_state_vector(y_train) - X = np.array(y_train) - Z = np.array(X_train) + xp = get_namespace(X_train, y_train) + _mT = xp.linalg.matrix_transpose + + X = xp.asarray(y_train) + Z = xp.asarray(X_train) n_samples = X.shape[0] # Calculate the transition matrix (from x_t to x_t+1) using least-squares X2 = X[1:, :] # x_{t+1} X1 = X[:-1, :] # x_t - A = X2.T @ X1 @ np.linalg.inv(X1.T @ X1) # Transition matrix - W = (X2 - X1 @ A.T).T @ (X2 - X1 @ A.T) / (n_samples - 1) # Covariance of transition matrix + A = _mT(X2) @ X1 @ xp.linalg.inv(_mT(X1) @ X1) # Transition matrix + W = _mT(X2 - X1 @ _mT(A)) @ (X2 - X1 @ _mT(A)) / (n_samples - 1) # Covariance of transition matrix # Calculate the measurement matrix (from x_t to z_t) using least-squares - H = Z.T @ X @ np.linalg.inv(X.T @ X) # Measurement matrix - Q = (Z - X @ H.T).T @ (Z - X @ H.T) / Z.shape[0] # Covariance of measurement matrix + H = _mT(Z) @ X @ xp.linalg.inv(_mT(X) @ X) # Measurement matrix + Q = _mT(Z - X @ _mT(H)) @ (Z - X @ _mT(H)) / Z.shape[0] # Covariance of measurement matrix self.A_state_transition_matrix = A self.W_process_noise_covariance = W * self.process_noise_scale @@ -132,12 +146,12 @@ def fit(self, X_train, y_train): def refit( self, - X_neural: np.ndarray, - Y_state: np.ndarray, + X_neural, + Y_state, intention_velocity_indices: int | None = None, - target_positions: np.ndarray | None = None, - cursor_positions: np.ndarray | None = None, - hold_indices: np.ndarray | None = None, + target_positions=None, + cursor_positions=None, + hold_indices=None, ): """ Refit the observation model based on new data. @@ -179,13 +193,17 @@ def refit( else: vel_idx = intention_velocity_indices + # The per-sample mutation loop uses numpy for element-wise operations + # on small vectors (np.linalg.norm on 2-element vectors, scalar indexing). + Y_state_np = np.asarray(Y_state) + target_positions_np = np.asarray(target_positions) if target_positions is not None else None + cursor_positions_np = np.asarray(cursor_positions) if cursor_positions is not None else None + # Only remap velocity if target and cursor positions are provided - if target_positions is None or cursor_positions is None: - intended_states = Y_state.copy() - else: - intended_states = Y_state.copy() + intended_states = Y_state_np.copy() + if target_positions_np is not None and cursor_positions_np is not None: # Calculate intended velocities for each sample - for i, (state, pos, target) in enumerate(zip(Y_state, cursor_positions, target_positions)): + for i, (state, pos, target) in enumerate(zip(Y_state_np, cursor_positions_np, target_positions_np)): is_hold = hold_indices[i] if hold_indices is not None else False if is_hold: @@ -213,14 +231,19 @@ def refit( else: intended_states[i, vel_idx : vel_idx + 2] = state[vel_idx : vel_idx + 2] - intended_states = np.array(intended_states) - Z = np.array(X_neural) + # Convert back to source namespace for final linalg + xp = get_namespace(X_neural) + dev = array_device(X_neural) + _mT = xp.linalg.matrix_transpose + + intended_states = xp_asarray(xp, intended_states, device=dev) + Z = xp.asarray(X_neural) # Recalculate observation matrix and noise covariance H = ( - Z.T @ intended_states @ np.linalg.pinv(intended_states.T @ intended_states) + _mT(Z) @ intended_states @ xp.linalg.pinv(_mT(intended_states) @ intended_states) ) # Using pinv() instead of inv() to avoid singular matrix errors - Q = (Z - intended_states @ H.T).T @ (Z - intended_states @ H.T) / Z.shape[0] + Q = _mT(Z - intended_states @ _mT(H)) @ (Z - intended_states @ _mT(H)) / Z.shape[0] self.H_observation_matrix = H self.Q_measurement_noise_covariance = Q @@ -236,56 +259,44 @@ def _compute_gain(self): Riccati equation to find the optimal steady-state gain. In non-steady-state mode, it computes the gain using the current covariance matrix. + The DARE solver requires NumPy arrays; results are converted back to the + source array namespace. + Raises: LinAlgError: If the Riccati equation cannot be solved or matrix operations fail. """ - # TODO: consider removing non-steady-state for compute_gain() - - # non_steady_state updates will occur during predict() and update() - # if self.steady_state: + xp = get_namespace(self.A_state_transition_matrix) + dev = array_device(self.A_state_transition_matrix) + _mT = xp.linalg.matrix_transpose + + # Convert to numpy for DARE (no Array API equivalent) + A_np = np.asarray(self.A_state_transition_matrix) + H_np = np.asarray(self.H_observation_matrix) + W_np = np.asarray(self.W_process_noise_covariance) + Q_np = np.asarray(self.Q_measurement_noise_covariance) + try: - # Try with original matrices - self.P_state_covariance = solve_discrete_are( - self.A_state_transition_matrix.T, - self.H_observation_matrix.T, - self.W_process_noise_covariance, - self.Q_measurement_noise_covariance, - ) - self.K_kalman_gain = ( - self.P_state_covariance - @ self.H_observation_matrix.T - @ np.linalg.inv( - self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T - + self.Q_measurement_noise_covariance - ) + P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_np) + self.P_state_covariance = xp_asarray(xp, P_np, device=dev) + S = ( + self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix) + + self.Q_measurement_noise_covariance ) + self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S) except LinAlgError: - # Apply regularization and retry - # A_reg = self.A_state_transition_matrix * 0.999 # Slight damping - # W_reg = self.W_process_noise_covariance + 1e-7 * np.eye( - # self.W_process_noise_covariance.shape[0] - # ) - Q_reg = self.Q_measurement_noise_covariance + 1e-7 * np.eye(self.Q_measurement_noise_covariance.shape[0]) + Q_reg_np = Q_np + 1e-7 * np.eye(Q_np.shape[0]) try: - self.P_state_covariance = solve_discrete_are( - self.A_state_transition_matrix.T, - self.H_observation_matrix.T, - self.W_process_noise_covariance, - Q_reg, - ) - self.K_kalman_gain = ( - self.P_state_covariance - @ self.H_observation_matrix.T - @ np.linalg.inv( - self.H_observation_matrix @ self.P_state_covariance @ self.H_observation_matrix.T + Q_reg - ) - ) + P_np = solve_discrete_are(A_np.T, H_np.T, W_np, Q_reg_np) + self.P_state_covariance = xp_asarray(xp, P_np, device=dev) + Q_reg = xp_asarray(xp, Q_reg_np, device=dev) + S = self.H_observation_matrix @ self.P_state_covariance @ _mT(self.H_observation_matrix) + Q_reg + self.K_kalman_gain = self.P_state_covariance @ _mT(self.H_observation_matrix) @ xp.linalg.inv(S) print("Warning: Used regularized matrices for DARE solution") except LinAlgError: # Fallback to identity or manual initialization print("Warning: DARE failed, using identity covariance") - self.P_state_covariance = np.eye(self.A_state_transition_matrix.shape[0]) - + self.P_state_covariance = xp_create(xp.eye, self.A_state_transition_matrix.shape[0], device=dev) # else: # n_states = self.A_state_transition_matrix.shape[0] # self.P_state_covariance = ( @@ -311,29 +322,35 @@ def _compute_gain(self): # I_mat - self.K_kalman_gain @ self.H_observation_matrix # ) @ P_m - def predict(self, x_current: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + def predict(self, x_current): """ Predict the next state and covariance. This method predicts the next state and covariance using the current state. """ + xp = get_namespace(x_current) + _mT = xp.linalg.matrix_transpose + x_predicted = self.A_state_transition_matrix @ x_current if self.steady_state is True: return x_predicted, None else: P_predicted = self.alpha_fading_memory**2 * ( - self.A_state_transition_matrix @ self.P_state_covariance @ self.A_state_transition_matrix.T + self.A_state_transition_matrix @ self.P_state_covariance @ _mT(self.A_state_transition_matrix) + self.W_process_noise_covariance ) return x_predicted, P_predicted def update( self, - z_measurement: np.ndarray, - x_predicted: np.ndarray, - P_predicted: np.ndarray | None = None, - ) -> np.ndarray: + z_measurement, + x_predicted, + P_predicted=None, + ): """Update state estimate and covariance based on measurement z.""" + xp = get_namespace(z_measurement, x_predicted) + dev = array_device(x_predicted) + _mT = xp.linalg.matrix_transpose # Compute residual innovation = z_measurement - self.H_observation_matrix @ x_predicted @@ -347,19 +364,23 @@ def update( # Non-steady-state mode # System uncertainty - S = self.H_observation_matrix @ P_predicted @ self.H_observation_matrix.T + self.Q_measurement_noise_covariance + S = ( + self.H_observation_matrix @ P_predicted @ _mT(self.H_observation_matrix) + + self.Q_measurement_noise_covariance + ) # Kalman gain - K = P_predicted @ self.H_observation_matrix.T @ np.linalg.pinv(S) + K = P_predicted @ _mT(self.H_observation_matrix) @ xp.linalg.pinv(S) # Updated state x_updated = x_predicted + K @ innovation # Covariance update - I_mat = np.eye(self.A_state_transition_matrix.shape[0]) - P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ ( + n = self.A_state_transition_matrix.shape[0] + I_mat = xp_create(xp.eye, n, device=dev) + P_updated = (I_mat - K @ self.H_observation_matrix) @ P_predicted @ _mT( I_mat - K @ self.H_observation_matrix - ).T + K @ self.Q_measurement_noise_covariance @ K.T + ) + K @ self.Q_measurement_noise_covariance @ _mT(K) # Save updated values self.P_state_covariance = P_updated diff --git a/src/ezmsg/learn/process/adaptive_linear_regressor.py b/src/ezmsg/learn/process/adaptive_linear_regressor.py index 3e00909..6ca65b1 100644 --- a/src/ezmsg/learn/process/adaptive_linear_regressor.py +++ b/src/ezmsg/learn/process/adaptive_linear_regressor.py @@ -1,3 +1,12 @@ +"""Adaptive linear regressor processor. + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. NaN checks and axis permutations + use Array API operations; a NumPy boundary is applied before sklearn + ``partial_fit``/``predict`` and before river ``learn_many``/``predict_many``. +""" + from dataclasses import field import ezmsg.core as ez @@ -6,6 +15,7 @@ import river.linear_model import river.optim import sklearn.base +from array_api_compat import get_namespace, is_numpy_array from ezmsg.baseproc import ( BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit, @@ -78,24 +88,32 @@ def _reset_state(self, message: AxisArray) -> None: pass def partial_fit(self, message: AxisArray) -> None: - if np.any(np.isnan(message.data)): + xp = get_namespace(message.data) + + if xp.any(xp.isnan(message.data)): return if self.settings.model_type in [ AdaptiveLinearRegressor.LINEAR, AdaptiveLinearRegressor.LOGISTIC, ]: - x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)}) + # river path: needs numpy/pandas + data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data + x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, data_np.T)}) y = pd.Series( data=message.attrs["trigger"].value.data[:, 0], name=message.attrs["trigger"].value.axes["ch"].data[0], ) self.state.model.learn_many(x, y) else: + # sklearn path: permute then convert to numpy X = message.data - if message.get_axis_idx("time") != 0: - X = np.moveaxis(X, message.get_axis_idx("time"), 0) - self.state.model.partial_fit(X, message.attrs["trigger"].value.data) + ax_idx = message.get_axis_idx("time") + if ax_idx != 0: + perm = (ax_idx,) + tuple(i for i in range(X.ndim) if i != ax_idx) + X = xp.permute_dims(X, perm) + X_np = np.asarray(X) if not is_numpy_array(X) else X + self.state.model.partial_fit(X_np, message.attrs["trigger"].value.data) self.state.template = replace( message.attrs["trigger"].value, @@ -107,16 +125,21 @@ def _process(self, message: AxisArray) -> AxisArray | None: if self.state.template is None: return AxisArray(np.array([]), dims=[""]) - if not np.any(np.isnan(message.data)): + xp = get_namespace(message.data) + + if not xp.any(xp.isnan(message.data)): if self.settings.model_type in [ AdaptiveLinearRegressor.LINEAR, AdaptiveLinearRegressor.LOGISTIC, ]: - # convert msg_in.data to something appropriate for river - x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, message.data.T)}) + # river path: needs numpy/pandas + data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data + x = pd.DataFrame.from_dict({k: v for k, v in zip(message.axes["ch"].data, data_np.T)}) preds = self.state.model.predict_many(x).values else: - preds = self.state.model.predict(message.data) + # sklearn path: needs numpy + data_np = np.asarray(message.data) if not is_numpy_array(message.data) else message.data + preds = self.state.model.predict(data_np) return replace( self.state.template, data=preds.reshape((len(preds), -1)), diff --git a/src/ezmsg/learn/process/refit_kalman.py b/src/ezmsg/learn/process/refit_kalman.py index 6684bc9..62ae835 100644 --- a/src/ezmsg/learn/process/refit_kalman.py +++ b/src/ezmsg/learn/process/refit_kalman.py @@ -1,13 +1,24 @@ +"""Refit Kalman filter processor for ezmsg. + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. State initialization and output + arrays use the source data's namespace. Buffered data for refitting + is stacked via ``xp.stack``. +""" + import pickle from pathlib import Path import ezmsg.core as ez import numpy as np +from array_api_compat import get_namespace from ezmsg.baseproc import ( BaseAdaptiveTransformer, BaseAdaptiveTransformerUnit, processor_state, ) +from ezmsg.sigproc.util.array import array_device, xp_create from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messages.util import replace @@ -54,8 +65,8 @@ class RefitKalmanFilterState: """ model: RefitKalmanFilter | None = None - x: np.ndarray | None = None - P: np.ndarray | None = None + x: object | None = None # Array API; namespace matches source data. + P: object | None = None # Array API; namespace matches source data. buffer_neural: list | None = None buffer_state: list | None = None @@ -142,7 +153,7 @@ def _init_model(self, **kwargs): config.update(kwargs) self._state.model = RefitKalmanFilter(**config) - def fit(self, X: np.ndarray, y: np.ndarray) -> None: + def fit(self, X, y) -> None: if self._state.model is None: self._init_model() if hasattr(self._state.model, "fit"): @@ -176,6 +187,11 @@ def save_checkpoint(self, checkpoint_path: str) -> None: Raises: ValueError: If the model is not initialized or has not been fitted. + + .. note:: + Checkpoint data is pickled as NumPy arrays. If model matrices live + on a non-NumPy backend (e.g. CuPy), pickle may fail — this is a + pre-existing limitation. """ if not self._state.model or not self._state.model.is_fitted: raise ValueError("Cannot save checkpoint: model not fitted") @@ -222,8 +238,15 @@ def _reset_state( if self._state.model.A_state_transition_matrix is not None: state_dim = self._state.model.A_state_transition_matrix.shape[0] - self._state.x = np.zeros(state_dim) - self._state.P = np.eye(state_dim) + # Derive xp/dev from message data when available; default to numpy. + if message is not None: + xp = get_namespace(message.data) + dev = array_device(message.data) + else: + xp, dev = np, None + + self._state.x = xp_create(xp.zeros, (state_dim,), dtype=xp.float64, device=dev) + self._state.P = xp_create(xp.eye, state_dim, dtype=xp.float64, device=dev) self._state.buffer_neural = [] self._state.buffer_state = [] @@ -252,15 +275,19 @@ def _process(self, message: AxisArray) -> AxisArray: # No checkpoint means you need to initialize and fit the model elif not self._state.model: self._init_model() + + xp = get_namespace(message.data) + dev = array_device(message.data) + state_dim = self._state.model.A_state_transition_matrix.shape[0] if self._state.x is None: - self._state.x = np.zeros(state_dim) + self._state.x = xp_create(xp.zeros, (state_dim,), dtype=xp.float64, device=dev) - filtered_data = np.zeros( - ( - message.data.shape[0], - self._state.model.A_state_transition_matrix.shape[0], - ) + filtered_data = xp_create( + xp.zeros, + (message.data.shape[0], state_dim), + dtype=message.data.dtype, + device=dev, ) for i in range(message.data.shape[0]): @@ -271,9 +298,9 @@ def _process(self, message: AxisArray) -> AxisArray: # Update x_updated = self._state.model.update(measurement, x_pred, P_pred) - # Store - self._state.x = x_updated.copy() - self._state.P = self._state.model.P_state_covariance.copy() + # Store — no .copy() needed, predict/update return new arrays + self._state.x = x_updated + self._state.P = self._state.model.P_state_covariance filtered_data[i] = self._state.x return replace( @@ -297,7 +324,8 @@ def partial_fit(self, message: AxisArray) -> None: if "trigger" not in message.attrs: raise ValueError("Invalid message format for partial_fit.") - X = np.array(message.data) + xp = get_namespace(message.data) + X = xp.asarray(message.data) values = message.attrs["trigger"].value if not isinstance(values, dict) or "Y_state" not in values: @@ -305,7 +333,7 @@ def partial_fit(self, message: AxisArray) -> None: kwargs = { "X_neural": X, - "Y_state": np.array(values["Y_state"]), + "Y_state": xp.asarray(values["Y_state"]), } # Optional fields @@ -316,7 +344,7 @@ def partial_fit(self, message: AxisArray) -> None: "hold_flags", ]: if key in values and values[key] is not None: - kwargs[key if key != "hold_flags" else "hold_indices"] = np.array(values[key]) + kwargs[key if key != "hold_flags" else "hold_indices"] = xp.asarray(values[key]) # Call model refit self._state.model.refit(**kwargs) @@ -324,7 +352,7 @@ def partial_fit(self, message: AxisArray) -> None: def log_for_refit( self, message: AxisArray, - target_position: np.ndarray | None = None, + target_position=None, hold_flag: bool | None = None, ): """ @@ -341,13 +369,13 @@ def log_for_refit( hold_flag: Boolean flag indicating if this is a hold period. """ if target_position is not None: - self._state.buffer_target_positions.append(target_position.copy()) + self._state.buffer_target_positions.append(target_position) if hold_flag is not None: self._state.buffer_hold_flags.append(hold_flag) measurement = message.data[-1] - self._state.buffer_neural.append(measurement.copy()) - self._state.buffer_state.append(self._state.x.copy()) + self._state.buffer_neural.append(measurement) + self._state.buffer_state.append(self._state.x) def refit_model(self): """ @@ -370,17 +398,19 @@ def refit_model(self): print("No buffered data to refit") return + xp = get_namespace(self._state.buffer_neural[0]) + kwargs = { - "X_neural": np.array(self._state.buffer_neural), - "Y_state": np.array(self._state.buffer_state), + "X_neural": xp.stack(self._state.buffer_neural), + "Y_state": xp.stack(self._state.buffer_state), "intention_velocity_indices": self.settings.velocity_indices[0], } if self._state.buffer_target_positions and self._state.buffer_cursor_positions: - kwargs["target_positions"] = np.array(self._state.buffer_target_positions) - kwargs["cursor_positions"] = np.array(self._state.buffer_cursor_positions) + kwargs["target_positions"] = xp.stack(self._state.buffer_target_positions) + kwargs["cursor_positions"] = xp.stack(self._state.buffer_cursor_positions) if self._state.buffer_hold_flags: - kwargs["hold_indices"] = np.array(self._state.buffer_hold_flags) + kwargs["hold_indices"] = xp.asarray(self._state.buffer_hold_flags) self._state.model.refit(**kwargs) diff --git a/src/ezmsg/learn/process/slda.py b/src/ezmsg/learn/process/slda.py index f3825a0..91a24b9 100644 --- a/src/ezmsg/learn/process/slda.py +++ b/src/ezmsg/learn/process/slda.py @@ -1,7 +1,17 @@ +"""Shrinkage LDA classifier processor. + +.. note:: + This module supports the Array API standard via + ``array_api_compat.get_namespace()``. Input data is manipulated using + Array API operations (``permute_dims``, ``reshape``); a NumPy boundary + is applied before ``sklearn.predict_proba``. +""" + import typing import ezmsg.core as ez import numpy as np +from array_api_compat import get_namespace, is_numpy_array from ezmsg.baseproc import ( BaseStatefulTransformer, BaseTransformerUnit, @@ -72,23 +82,29 @@ def _reset_state(self, message: AxisArray) -> None: ) def _process(self, message: AxisArray) -> ClassifierMessage: + xp = get_namespace(message.data) samp_ax_idx = message.dims.index(self.settings.axis) - X = np.moveaxis(message.data, samp_ax_idx, 0) + + # Move sample axis to front + perm = (samp_ax_idx,) + tuple(i for i in range(message.data.ndim) if i != samp_ax_idx) + X = xp.permute_dims(message.data, perm) if X.shape[0]: if isinstance(self.settings.settings_path, str) and self.settings.settings_path[-4:] == ".mat": - # Assumes F-contiguous weights + # Assumes F-contiguous weights — need numpy for predict_proba + X_np = np.asarray(X) if not is_numpy_array(X) else X pred_probas = [] - for samp in X: + for samp in X_np: tmp = samp.flatten(order="F") * 1e-6 tmp = np.expand_dims(tmp, axis=0) probas = self.state.lda.predict_proba(tmp) pred_probas.append(probas) pred_probas = np.concatenate(pred_probas, axis=0) else: - # This creates a copy. - X = X.reshape(X.shape[0], -1) - pred_probas = self.state.lda.predict_proba(X) + # Numpy boundary before sklearn predict_proba + X_np = np.asarray(X) if not is_numpy_array(X) else X + X_np = X_np.reshape(X_np.shape[0], -1) + pred_probas = self.state.lda.predict_proba(X_np) update_ax = self.state.out_template.axes[self.settings.axis] update_ax.offset = message.axes[self.settings.axis].offset diff --git a/tests/unit/test_cca.py b/tests/unit/test_cca.py new file mode 100644 index 0000000..a4ffdb5 --- /dev/null +++ b/tests/unit/test_cca.py @@ -0,0 +1,125 @@ +import numpy as np +import pytest + +from ezmsg.learn.model.cca import IncrementalCCA, _inv_sqrtm_spd + + +@pytest.fixture +def rng(): + return np.random.default_rng(42) + + +@pytest.fixture +def synthetic_data(rng): + """Create correlated synthetic data for CCA testing.""" + n_samples, d1, d2 = 100, 5, 4 + # Shared latent factor + latent = rng.standard_normal((n_samples, 2)) + X1 = latent @ rng.standard_normal((2, d1)) + 0.1 * rng.standard_normal((n_samples, d1)) + X2 = latent @ rng.standard_normal((2, d2)) + 0.1 * rng.standard_normal((n_samples, d2)) + return X1, X2 + + +def test_initialize(): + """Test that initialize creates matrices with correct shapes.""" + cca = IncrementalCCA(n_components=2) + cca.initialize(5, 4) + + assert cca.C11.shape == (5, 5) + assert cca.C22.shape == (4, 4) + assert cca.C12.shape == (5, 4) + assert cca.initialized is True + assert cca.d1 == 5 + assert cca.d2 == 4 + + +def test_initialize_with_ref_array(): + """Test that initialize respects ref_array namespace.""" + ref = np.zeros(3) + cca = IncrementalCCA(n_components=2) + cca.initialize(5, 4, ref_array=ref) + + assert cca.C11.shape == (5, 5) + assert cca.C11.dtype == np.float64 + + +def test_partial_fit(synthetic_data): + """Test that partial_fit runs without error and updates covariance.""" + X1, X2 = synthetic_data + cca = IncrementalCCA(n_components=2) + cca.partial_fit(X1, X2) + + assert cca.initialized is True + assert not np.allclose(cca.C11, 0) + assert not np.allclose(cca.C22, 0) + assert not np.allclose(cca.C12, 0) + assert hasattr(cca, "x_weights_") + assert hasattr(cca, "y_weights_") + + +def test_partial_fit_incremental(synthetic_data): + """Test that multiple partial_fit calls update smoothly.""" + X1, X2 = synthetic_data + cca = IncrementalCCA(n_components=2) + + # First fit + cca.partial_fit(X1[:50], X2[:50]) + C11_first = cca.C11.copy() + + # Second fit should change covariance + cca.partial_fit(X1[50:], X2[50:]) + assert not np.allclose(C11_first, cca.C11) + + +def test_transform(synthetic_data): + """Test that transform produces correct output shapes.""" + X1, X2 = synthetic_data + cca = IncrementalCCA(n_components=2) + cca.partial_fit(X1, X2) + + X1_proj, X2_proj = cca.transform(X1, X2) + assert X1_proj.shape == (100, 2) + assert X2_proj.shape == (100, 2) + + +def test_transform_single_component(synthetic_data): + """Test with a single canonical component.""" + X1, X2 = synthetic_data + cca = IncrementalCCA(n_components=1) + cca.partial_fit(X1, X2) + + X1_proj, X2_proj = cca.transform(X1, X2) + assert X1_proj.shape == (100, 1) + assert X2_proj.shape == (100, 1) + + +def test_numerical_equivalence_inv_sqrtm(): + """Compare eigh-based _inv_sqrtm_spd with scipy.linalg on known SPD matrices.""" + scipy_linalg = pytest.importorskip("scipy.linalg") + + rng = np.random.default_rng(123) + # Create a known SPD matrix + A_raw = rng.standard_normal((5, 5)) + A = A_raw.T @ A_raw + 0.1 * np.eye(5) # Ensure SPD + + # scipy reference + sqrtm_A = scipy_linalg.sqrtm(A) + inv_sqrtm_scipy = scipy_linalg.inv(np.real(sqrtm_A)) + + # Our implementation + inv_sqrtm_eigh = _inv_sqrtm_spd(np, A) + + np.testing.assert_allclose(inv_sqrtm_eigh, inv_sqrtm_scipy, atol=1e-10) + + +def test_correlations_are_computed(synthetic_data): + """Test that canonical correlations are stored after fit.""" + X1, X2 = synthetic_data + cca = IncrementalCCA(n_components=2) + cca.partial_fit(X1, X2) + + assert hasattr(cca, "correlations_") + assert len(cca.correlations_) >= 2 + # Correlations should be between 0 and 1 for well-conditioned data + assert np.all(cca.correlations_[:2] >= 0) + assert np.all(cca.correlations_[:2] <= 1.0 + 1e-6) From 60abdb710add5c8f35ab2f92228a677220d7b2d8 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Sat, 7 Feb 2026 03:37:16 -0500 Subject: [PATCH 2/3] Add Array API documentation --- docs/source/guides/array_api.rst | 246 +++++++++++++++++++++++++++++++ docs/source/index.rst | 1 + 2 files changed, 247 insertions(+) create mode 100644 docs/source/guides/array_api.rst diff --git a/docs/source/guides/array_api.rst b/docs/source/guides/array_api.rst new file mode 100644 index 0000000..d7e15d5 --- /dev/null +++ b/docs/source/guides/array_api.rst @@ -0,0 +1,246 @@ +Array API Compatibility +======================= + +ezmsg-learn uses the `Array API standard `_ +to allow processors to operate on arrays from different backends — NumPy, CuPy, +PyTorch, and others — without code changes. + +.. contents:: On this page + :local: + :depth: 2 + + +How It Works +------------ + +Modules that support the Array API derive the array namespace from their input +data using ``array_api_compat.get_namespace()``: + +.. code-block:: python + + from array_api_compat import get_namespace + + def process(self, data): + xp = get_namespace(data) # numpy, cupy, torch, etc. + result = xp.linalg.inv(data) # dispatches to the right backend + return result + +This means that if you pass a CuPy array, all computation stays on the GPU. +If you pass a NumPy array, it behaves exactly as before. + +Helper utilities from ``ezmsg.sigproc.util.array`` handle device placement +and creation functions portably: + +- ``array_device(x)`` — returns the device of an array, or ``None`` +- ``xp_create(fn, *args, dtype=None, device=None)`` — calls creation + functions (``zeros``, ``eye``) with optional device +- ``xp_asarray(xp, obj, dtype=None, device=None)`` — portable ``asarray`` + + +Module Compatibility +-------------------- + +The table below summarises the Array API status of each module. + +Fully compatible +^^^^^^^^^^^^^^^^ + +These modules perform all computation in the source array namespace. + +.. list-table:: + :header-rows: 1 + :widths: 35 65 + + * - Module + - Notes + * - ``process.ssr`` + - LRR / self-supervised regression. Full Array API. + * - ``model.cca`` + - Incremental CCA. Replaced ``scipy.linalg.sqrtm`` with an + eigendecomposition-based inverse square root using only Array API ops. + * - ``process.rnn`` + - PyTorch-native; operates on ``torch.Tensor`` throughout. + +Mostly compatible (with NumPy boundaries) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +These modules use the Array API for data manipulation but fall back to NumPy +at specific points where a dependency requires it. + +.. list-table:: + :header-rows: 1 + :widths: 25 35 40 + + * - Module + - NumPy boundary + - Reason + * - ``model.refit_kalman`` + - ``_compute_gain()`` + - ``scipy.linalg.solve_discrete_are`` has no Array API equivalent. + Matrices are converted to NumPy for the DARE solver, then converted back. + * - ``model.refit_kalman`` + - ``refit()`` mutation loop + - Per-sample velocity remapping uses ``np.linalg.norm`` on small vectors + and scalar element assignment. + * - ``process.refit_kalman`` + - Inherits boundaries from model + - State init and output arrays use the source namespace. + * - ``process.slda`` + - ``predict_proba`` + - sklearn ``LinearDiscriminantAnalysis`` requires NumPy input. + * - ``process.adaptive_linear_regressor`` + - ``partial_fit`` / ``predict`` + - sklearn and river models require NumPy / pandas input. + * - ``dim_reduce.adaptive_decomp`` + - ``partial_fit`` / ``transform`` + - sklearn ``IncrementalPCA`` and ``MiniBatchNMF`` require NumPy input. + +Not converted +^^^^^^^^^^^^^ + +These modules use NumPy directly. Conversion would provide little benefit +because the underlying estimator is the bottleneck. + +.. list-table:: + :header-rows: 1 + :widths: 25 75 + + * - Module + - Reason + * - ``process.linear_regressor`` + - Thin wrapper around sklearn ``LinearModel.predict``. + Could be made compatible if sklearn's ``array_api_dispatch`` is enabled + (see below). + * - ``process.sgd`` + - sklearn ``SGDClassifier`` has no Array API support. + * - ``process.sklearn`` + - Generic wrapper for arbitrary models; cannot assume Array API support. + * - ``dim_reduce.incremental_decomp`` + - Delegates to ``adaptive_decomp``; trivial numpy usage (``np.prod`` on + Python tuples). + + +sklearn Array API Dispatch +-------------------------- + +scikit-learn 1.8+ has experimental support for Array API dispatch on a subset +of estimators. Two estimators used in ezmsg-learn are on the supported list: + +.. list-table:: + :header-rows: 1 + :widths: 30 30 40 + + * - Estimator + - Used in + - Constraint + * - ``LinearDiscriminantAnalysis`` + - ``process.slda`` + - Requires ``solver="svd"`` (the ``"lsqr"`` solver with ``shrinkage`` + is not supported) + * - ``Ridge`` + - ``process.linear_regressor`` + - Requires ``solver="svd"`` + +To use dispatch, enable it before creating the estimator: + +.. code-block:: python + + from sklearn import set_config + set_config(array_api_dispatch=True) + +.. warning:: + + - ``array_api_dispatch`` is marked **experimental** in sklearn. + - Solver constraints (``solver="svd"``) may produce slightly different + numerical results compared to other solvers. + - Enabling dispatch globally may affect other sklearn estimators in the + same process. + - ezmsg-learn does **not** enable dispatch by default. + +Estimators that do **not** support Array API dispatch: + +- ``IncrementalPCA``, ``MiniBatchNMF`` — only batch ``PCA`` is supported +- ``SGDClassifier``, ``SGDRegressor``, ``PassiveAggressiveRegressor`` +- All river models + + +Writing Array API Compatible Code +---------------------------------- + +When adding or modifying processors in ezmsg-learn, follow these patterns. + +Deriving the namespace +^^^^^^^^^^^^^^^^^^^^^^ + +Always derive ``xp`` from the input data, not from a hardcoded ``numpy``: + +.. code-block:: python + + from array_api_compat import get_namespace + from ezmsg.sigproc.util.array import array_device, xp_create + + def _process(self, message): + xp = get_namespace(message.data) + dev = array_device(message.data) + +Transposing matrices +^^^^^^^^^^^^^^^^^^^^ + +The Array API does not support ``.T``. Use ``xp.linalg.matrix_transpose()``: + +.. code-block:: python + + # Before (numpy-only) + result = A.T @ B + + # After (Array API) + _mT = xp.linalg.matrix_transpose + result = _mT(A) @ B + +Creating arrays +^^^^^^^^^^^^^^^ + +Use ``xp_create`` to handle device placement portably: + +.. code-block:: python + + # Before + I = np.eye(n) + z = np.zeros((m, n), dtype=np.float64) + + # After + I = xp_create(xp.eye, n, device=dev) + z = xp_create(xp.zeros, (m, n), dtype=xp.float64, device=dev) + +Handling sklearn boundaries +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +When calling into sklearn (or other NumPy-only libraries), convert at the +boundary and convert back: + +.. code-block:: python + + from array_api_compat import is_numpy_array + + # Convert to numpy for sklearn + X_np = np.asarray(X) if not is_numpy_array(X) else X + result_np = estimator.predict(X_np) + + # Convert back to source namespace + result = xp.asarray(result_np) if not is_numpy_array(X) else result_np + +Checking for NaN +^^^^^^^^^^^^^^^^ + +Use ``xp.isnan`` instead of ``np.isnan``: + +.. code-block:: python + + if xp.any(xp.isnan(message.data)): + return + +Norms +^^^^^ + +Use ``xp.linalg.matrix_norm`` (Frobenius by default) instead of +``np.linalg.norm`` for matrices. For vectors, use ``xp.linalg.vector_norm``. diff --git a/docs/source/index.rst b/docs/source/index.rst index db90cf2..e240d38 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -54,6 +54,7 @@ For general ezmsg tutorials and guides, visit `ezmsg.org :caption: Contents: guides/classification + guides/array_api api/index From 9c12e327dc2030cf778e4e7115e687f24fe22024 Mon Sep 17 00:00:00 2001 From: Chadwick Boulay Date: Wed, 11 Feb 2026 15:33:54 -0500 Subject: [PATCH 3/3] Bump dependencies --- pyproject.toml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b175a1a..315ec87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,8 +9,9 @@ license = "MIT" requires-python = ">=3.10.15" dynamic = ["version"] dependencies = [ - "ezmsg-baseproc>=1.4.0", - "ezmsg-sigproc>=2.15.0", + "ezmsg>=3.7.3", + "ezmsg-baseproc>=1.5.1", + "ezmsg-sigproc>=2.17.0", "river>=0.22.0", "scikit-learn>=1.6.0", "torch>=2.6.0", @@ -73,5 +74,4 @@ known-third-party = ["ezmsg", "ezmsg.baseproc", "ezmsg.sigproc"] [tool.uv.sources] # Uncomment to use development version of ezmsg from git -#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" } -#ezmsg-sigproc = { path = "../ezmsg-sigproc", editable = true } \ No newline at end of file +#ezmsg = { git = "https://github.com/ezmsg-org/ezmsg.git", branch = "feature/profiling" } \ No newline at end of file