diff --git a/docs/API_DESIGN_DOCUMENT.md b/docs/API_DESIGN_DOCUMENT.md index 4611126..ef821e4 100644 --- a/docs/API_DESIGN_DOCUMENT.md +++ b/docs/API_DESIGN_DOCUMENT.md @@ -336,7 +336,7 @@ src/psyphy/ │ └── expected_improvement.py │ ├── data/ # Data handling -│ ├── dataset.py # ResponseData class +│ ├── dataset.py # ResponseData class, TrialData class │ └── io.py # Loading/saving │ └── utils/ # Utilities diff --git a/docs/examples/wppm/full_wppm_fit_example.md b/docs/examples/wppm/full_wppm_fit_example.md index c1809c5..b18dfa9 100644 --- a/docs/examples/wppm/full_wppm_fit_example.md +++ b/docs/examples/wppm/full_wppm_fit_example.md @@ -77,11 +77,13 @@ Note on data used in this script: here, we simulate data (and hence have a groun - The canonical, (batched) input expected by likelihood evaluation and optimizers (e.g., MAPOptimizer.fit(...)). It holds JAX arrays: - - refs: (N, d) - - comparisons: (N, d) - - responses: (N,) + - stimuli: (N, K, d) + - responses: (N, R) + - context: (N, C) (optional) -where $d$ refers to the input dimension, here 2. +where $N$ is the number of trials, $K$ is the number of distinct stimuli per trial (e.g. K=2 in the Oddity task: reference and comparison), $d$ is the stimulus space dimensionality, $R$ is the number of response channels, and $C$ is the number of context channels. + +Context is an optional attribute which is intended to track task-wide features that vary across trials which might condition likelihoods. For example, psyiological metrics like pupil dilation could be included as context. PsyPhy does not yet have inbuilt uses for context. **`ResponseData` (collection/I/O-first; convenient for experiments):** - A Python-friendly log (stores trials in lists) designed for incremental collection, saving/loading (e.g., CSV), and adaptive experiments but expensive for computation diff --git a/docs/examples/wppm/full_wppm_fit_example.py b/docs/examples/wppm/full_wppm_fit_example.py index db3380d..efc1add 100644 --- a/docs/examples/wppm/full_wppm_fit_example.py +++ b/docs/examples/wppm/full_wppm_fit_example.py @@ -115,7 +115,7 @@ def _ellipse_segments_from_covs( eigvals = jnp.linalg.eigvalsh(covs) valid = jnp.all(eigvals > 0, axis=-1) - # Cholesky is faster than eigendecomposition for SPD matrices. + # Cholesky is faster than eigen decomposition for SPD matrices. # we only use it for plotting ellipses (shape/orientation), not inference. def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray: L = jnp.linalg.cholesky(cov) @@ -257,13 +257,16 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray: # Build the canonical batched dataset for compute. # # Notes: -# - This is equivalent to storing X with shape (N, 2, d) and y with shape (N,) -# where X[:,0,:]=refs and X[:,1,:]=comparisons. -# - We keep named fields because it's currently native to OddityTask. -# - Even though oddity is a 3-item task, we only store (ref, comparison) +# - This is equivalent to storing X with shape (N, K, d) and y with shape (N, R) +# where N is trials, K is number of distinct stimuli per trial, d is stimulus +# dimensionality, and R is number of response channels. +# - For OddityTask, X[:,0,:]=refs and X[:,1,:]=comparisons. +# - Note that even though oddity is a 3-item task, we only store (ref, comparison) # because the oddity trial is assumed to be (ref, ref, comparison) +# # --8<-- [start:data] -data = TrialData(refs=refs, comparisons=comparisons, responses=ys) +stimuli = jnp.stack([refs, comparisons], axis=1) +data = TrialData(stimuli=stimuli, responses=ys) # --8<-- [end:data] # --8<-- [end:simulate_data] diff --git a/docs/examples/wppm/plots/quick_start_ellipses.png b/docs/examples/wppm/plots/quick_start_ellipses.png index d195de5..ceca592 100644 Binary files a/docs/examples/wppm/plots/quick_start_ellipses.png and b/docs/examples/wppm/plots/quick_start_ellipses.png differ diff --git a/docs/examples/wppm/plots/quick_start_learning_curve.png b/docs/examples/wppm/plots/quick_start_learning_curve.png index 2ba3f42..469767b 100644 Binary files a/docs/examples/wppm/plots/quick_start_learning_curve.png and b/docs/examples/wppm/plots/quick_start_learning_curve.png differ diff --git a/docs/examples/wppm/quick_start.md b/docs/examples/wppm/quick_start.md index 24cdf1b..287ec50 100644 --- a/docs/examples/wppm/quick_start.md +++ b/docs/examples/wppm/quick_start.md @@ -129,5 +129,5 @@ initialization (blue), and the MAP fit (red) at the single reference point: ## Next steps - **Spatially-varying field:** scale up to a full 2-D grid → [full example](full_wppm_fit_example.md). -- **Your own data:** replace the simulated `TrialData` with your own `refs`, `comparisons`, and `responses` arrays. +- **Your own data:** replace the simulated `TrialData` with your own `inputs` and `responses` arrays. (currently calculations are only supported when the inputs are "refs" and "comparisons.") - **API reference:** see [`MAPOptimizer`](../../reference/inference.md), [`WPPM`](../../reference/model.md), and [`WPPMCovarianceField`](../../reference/model.md). diff --git a/docs/examples/wppm/quick_start.py b/docs/examples/wppm/quick_start.py index 7e980b2..3e063fa 100644 --- a/docs/examples/wppm/quick_start.py +++ b/docs/examples/wppm/quick_start.py @@ -166,9 +166,8 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray: # --8<-- [end:simulate_data] # --8<-- [start:data] -data = TrialData( - refs=refs, comparisons=comparisons, responses=ys -) # contains 3 JAX arrays +stimuli = jnp.stack([refs, comparisons], axis=1) +data = TrialData(stimuli=stimuli, responses=ys) # contains 2 JAX arrays # --8<-- [end:data] diff --git a/src/psyphy/__init__.py b/src/psyphy/__init__.py index bb68f75..6e293a7 100644 --- a/src/psyphy/__init__.py +++ b/src/psyphy/__init__.py @@ -82,7 +82,7 @@ # from . import session as session from . import trial_placement as trial_placement from . import utils as utils -from .data.dataset import ResponseData, TrialBatch +from .data.dataset import ResponseData, TrialBatch, TrialData from .inference.langevin import LangevinSampler from .inference.laplace import LaplaceApproximation @@ -116,6 +116,7 @@ # Data handling "ResponseData", "TrialBatch", + "TrialData", # Subpackages "model", "inference", diff --git a/src/psyphy/data/dataset.py b/src/psyphy/data/dataset.py index 61d8c4d..efe5117 100644 --- a/src/psyphy/data/dataset.py +++ b/src/psyphy/data/dataset.py @@ -33,46 +33,90 @@ class TrialData: Shapes ------ - refs : (N, d) - comparisons : (N, d) - responses : (N,) + stimuli : (N, K, d) + responses : (N, R) + context : optional (N, C) + + Dimension key + ------------- + N : number of trials (batch dimension) + K : number of stimuli per trial (e.g. K=2 for a two-alternative task; + K=2 for the oddity task, where the reference is presented twice but + only the unique mean is stored — the duplication is encoded in the + task likelihood, not here) + d : dimensionality of each stimulus coordinate + R : number of response channels (R=1 for binary; R=2 for e.g. (choice, RT)) + C : number of context channels (observer-state covariates that condition the + likelihood but are not part of the stimulus space, e.g. fatigue level) Notes ----- - You can also think of this as a more generic ML-style dataset - ``X`` with shape (N, 2, d) plus ``y`` with shape (N,). The explicit - field names (refs/comparisons) are currently native to :class:`OddityTask`. + ``X`` with shape (N, K, d) plus ``y`` with shape (N, R). - This is intended to be JAX-friendly (PyTree of arrays) so likelihood and inference code can be JIT-compiled without touching Python containers. + - Context is optional. No current inbuilt uses. """ - refs: jnp.ndarray # TODO: references - comparisons: jnp.ndarray + stimuli: jnp.ndarray responses: jnp.ndarray + context: jnp.ndarray | None = None + stimulus_names: tuple[str, ...] = () def __post_init__(self) -> None: + # Callers that construct TrialData directly (bypassing ResponseData.add_trial) + # may pass responses as a plain 1D array (N,). Expand to + # (N, 1) here so strict shape validation below does not break those call sites. + # object.__setattr__ is required because the dataclass is frozen; __post_init__ + # is the only window to mutate fields during construction. + if self.responses.ndim == 1: + object.__setattr__(self, "responses", self.responses[:, None]) # Basic shape validation (keep lightweight; raise early for common mistakes). - if self.refs.ndim != 2: - raise ValueError(f"refs must be 2D (N,d), got shape {self.refs.shape}") - if self.comparisons.ndim != 2: + if self.stimuli.ndim != 3: raise ValueError( - f"comparisons must be 2D (N,d), got shape {self.comparisons.shape}" + f"stimuli must be 3D (N, K, d), got shape {self.stimuli.shape}" ) - if self.responses.ndim != 1: + if self.responses.ndim != 2: raise ValueError( - f"responses must be 1D (N,), got shape {self.responses.shape}" + f"responses must be 2D (N, R), got shape {self.responses.shape}" ) - if self.refs.shape[0] != self.comparisons.shape[0]: + if self.stimuli.shape[0] != self.responses.shape[0]: raise ValueError( - "refs and comparisons must have same first dimension; " - f"got {self.refs.shape[0]} vs {self.comparisons.shape[0]}" + "stimuli and responses must have same first dimension; " + f"got {self.stimuli.shape[0]} vs {self.responses.shape[0]}." ) - if self.refs.shape[0] != self.responses.shape[0]: + if self.context is not None and self.context.shape[0] != self.stimuli.shape[0]: raise ValueError( - "refs and responses must have same first dimension; " - f"got {self.refs.shape[0]} vs {self.responses.shape[0]}" + "if context is provided, it must share the same first dimension;" + f"got {self.context.shape[0]} vs {self.stimuli.shape[0]}." ) + def stimulus(self, name: str) -> jnp.ndarray: + """Return stimuli[:, k, :] for the slot named `name`. + + Parameters + ---------- + name : str + Must match one of the entries in ``stimulus_names``. + + Returns + ------- + jnp.ndarray, shape (N, d) + Stimulus coordinates for all trials at the named slot. + """ + if not self.stimulus_names: + raise ValueError( + "stimulus_names is empty — set it at construction time to use " + "named access, e.g. stimulus_names=('ref', 'comp')." + ) + if name not in self.stimulus_names: + raise ValueError( + f"unknown stimulus name '{name}'. " + f"Available names: {self.stimulus_names}." + ) + idx = self.stimulus_names.index(name) # resolved in Python, not JAX-traced + return self.stimuli[:, idx, :] + def __len__(self) -> int: """Number of trials (N).""" return int(self.responses.shape[0]) @@ -93,144 +137,196 @@ class ResponseData: """ def __init__(self) -> None: - self.refs: list[Any] = [] - self.comparisons: list[Any] = [] - self.responses: list[int] = [] + self.stimuli: list[np.array] = [] + self.responses: list[np.array] = [] + self.stim_shape: tuple | None = None # set on first add_trial call + self.contexts: list[np.array] = [] - def add_trial(self, ref: Any, comparison: Any, resp: int) -> None: + def add_trial(self, input: tuple[Any, ...], resp: Any, context: Any = None) -> None: """ append a single trial. Parameters ---------- - ref : Any - Reference stimulus (numpy array, list, etc.) - comparison : Any - Probe stimulus - resp : int - Subject response (binary or categorical) + input : tuple(Any, ...) + Group of presented stimuli each represented in any format (numpy array, + list, etc.) + Input must contain appropriate number of stimuli of appropriate dimension. + resp : Any + Subject response """ - self.refs.append(ref) - self.comparisons.append(comparison) - self.responses.append(resp) - - def add_batch(self, responses: list[int], trial_batch: TrialBatch) -> None: + input_arr = np.atleast_2d(np.asarray(input)) # (K, d) — 1D input treated as K=1 + resp_arr = np.atleast_1d(np.asarray(resp)) # (R,) — scalar treated as R=1 + if self.stimuli: + if self.stim_shape != input_arr.shape: + raise ValueError( + f"stimuli must have consistent shape (K, d). Expected {self.stim_shape}, but received {input_arr.shape}" + ) + else: + self.stim_shape = input_arr.shape + + if context is None: + if self.contexts: + raise ValueError( + "Context cannot be omitted if it was included in previous trials." + "This ResponseData instance expected context but received none." + ) + else: + if self.contexts or self.stimuli == []: + self.contexts.append(np.asarray(context)) + else: + raise ValueError( + "Context cannot be accepted if it was excluded from prior trials." + f"This ResponseData instance expected no context, but received {context}" + ) + self.stimuli.append(input_arr) + self.responses.append(resp_arr) + + def add_batch( + self, + responses: list[Any], + trial_batch: TrialBatch, + contexts: list[Any] | None = None, + ) -> None: """ Append responses for a batch of trials. Parameters ---------- - responses : List[int] - Responses corresponding to each (ref, comparison) in the trial batch. + responses : List[Any] + Responses corresponding to each stimulus group in the trial batch. trial_batch : TrialBatch The batch of proposed trials. """ - for (ref, comparison), resp in zip(trial_batch.stimuli, responses): - self.add_trial(ref, comparison, resp) + if contexts is None: + for input, resp in zip(trial_batch.stimuli, responses): + self.add_trial(input, resp) + else: + for input, resp, context in zip(trial_batch.stimuli, responses, contexts): + self.add_trial(input, resp, context) - def to_numpy(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]: - """Return refs, comparisons, responses as NumPy arrays.""" + def to_numpy(self) -> tuple[np.ndarray, np.ndarray]: + """Return stimuli and responses as NumPy arrays. + Will NOT include contexts by default. Output always fixed length of 2. + """ return ( - np.asarray(self.refs), - np.asarray(self.comparisons), + np.asarray(self.stimuli), # shape = (N, K, d) np.asarray(self.responses), ) def to_trial_data(self) -> TrialData: """Convert this log into the canonical JAX batch (:class:`TrialData`).""" - refs, comparisons, responses = self.to_numpy() - return TrialData( - refs=jnp.asarray(refs), - comparisons=jnp.asarray(comparisons), - responses=jnp.asarray(responses), - ) + stimuli, responses = self.to_numpy() + if self.contexts: + context = np.asarray(self.contexts) + return TrialData( + stimuli=jnp.asarray(stimuli), + responses=jnp.asarray(responses), + context=jnp.asarray(context), + ) + else: + return TrialData( + stimuli=jnp.asarray(stimuli), responses=jnp.asarray(responses) + ) @property - def trials(self) -> list[tuple[Any, Any, int]]: + def trials(self) -> list[tuple[Any, ...]]: """ - Return list of (ref, comparison, response) tuples. + Return list of (stim1, stim2, ... , response) tuples. + Does NOT include context information. Returns ------- list[tuple] - Each element is (ref, comparison, resp) + Each element is tuple representing all stimuli and the associated + response for a given trial. """ - return list(zip(self.refs, self.comparisons, self.responses)) + return [i + (r,) for i, r in zip(self.stimuli, self.responses)] def __len__(self) -> int: """Return number of trials.""" - return len(self.refs) + return len(self.stimuli) @classmethod def from_arrays( cls, X: jnp.ndarray | np.ndarray, y: jnp.ndarray | np.ndarray, - *, - comparisons: jnp.ndarray | np.ndarray | None = None, + c: jnp.ndarray | np.ndarray | None = None, ) -> ResponseData: """ Construct ResponseData from arrays. Parameters ---------- - X : array, shape (n_trials, 2, input_dim) or (n_trials, input_dim) - Stimuli. If 3D, second axis is [reference, comparison]. - If 2D, comparisons must be provided separately. - y : array, shape (n_trials,) + X : array, shape (n_trials, n_stimuli, input_dim) or (n_trials, input_dim) + Stimuli. If 3D, second axis is input stumili. For OddityTask, this is + (ref, comparison) + y : array, shape (n_trials, response_dim) Responses - comparisons : array, shape (n_trials, input_dim), optional - Probe stimuli. Only needed if X is 2D. + c : optional array, shape (n_trials, context_dim) + Context Returns ------- ResponseData Data container - Examples + OddityTask Example -------- >>> # From paired stimuli >>> X = jnp.array([[[0, 0], [1, 0]], [[1, 1], [2, 1]]]) + >>> # X is formed from refs = [[0, 0], [1, 1]], comparisons = [[1, 0], [2, 1]] >>> y = jnp.array([1, 0]) >>> data = ResponseData.from_arrays(X, y) - - >>> # From separate refs and comparisons - >>> refs = jnp.array([[0, 0], [1, 1]]) - >>> comparisons = jnp.array([[1, 0], [2, 1]]) - >>> data = ResponseData.from_arrays(refs, y, comparisons=comparisons) """ data = cls() X = np.asarray(X) y = np.asarray(y) - if X.ndim == 3: - # X is (n_trials, 2, input_dim) - refs = X[:, 0, :] - comparisons_arr = X[:, 1, :] - elif X.ndim == 2 and comparisons is not None: - refs = X - comparisons_arr = np.asarray(comparisons) - else: + if X.ndim == 2: + # reshape to ensure appropriate conversion to stimuli groups + dims = X.shape + new_dims = (dims[0], 1, dims[1]) + X = np.reshape(X, new_dims) + elif X.ndim != 3: raise ValueError( - "X must be shape (n_trials, 2, input_dim) or " - "(n_trials, input_dim) with comparisons argument" + "X must be shape (n_trials, n_stimuli, input_dim) or \ + (n_trials, input_dim)." ) - - for ref, comparison, response in zip(refs, comparisons_arr, y): - data.add_trial(ref, comparison, int(response)) + if y.shape[0] != X.shape[0]: + raise ValueError("X and y must contain the same n_trials.") + if c is not None and c.shape[0] != X.shape[0]: + raise ValueError("c must contain same n_trials as X.") + + # X is (n_trials, K, d) — split into per-trial tuples of K stimulus rows + stimuli = [] + for plane in X: + stimuli.append(tuple(plane)) + + if c is not None: + for stim, response, context in zip(stimuli, y, c): + data.add_trial(stim, response, context) + else: + for stim, response in zip(stimuli, y): + data.add_trial(stim, response) return data @classmethod def from_trial_data(cls, data: TrialData) -> ResponseData: """Build a ResponseData log from a :class:`TrialData` batch.""" - refs = np.asarray(data.refs) - comps = np.asarray(data.comparisons) + stimuli = np.asarray(data.stimuli) ys = np.asarray(data.responses) out = cls() - for r, c, y in zip(refs, comps, ys): - out.add_trial(r, c, int(y)) + if data.context is not None: + cs = np.asarray(data.context) + for s, y, c in zip(stimuli, ys, cs): + out.add_trial(s, y, c) + else: + for s, y in zip(stimuli, ys): + out.add_trial(s, y) return out def merge(self, other: ResponseData) -> None: @@ -242,9 +338,31 @@ def merge(self, other: ResponseData) -> None: other : ResponseData Dataset to merge """ - self.refs.extend(other.refs) - self.comparisons.extend(other.comparisons) + no_empty = self.stimuli and other.stimuli + + if no_empty and self.stimuli[0].shape != other.stimuli[0].shape: + raise ValueError( + "Cannot merge ResponseData instances with inconsistent input shapes." + f"Received input shapes of {self.stimuli[0].shape} and {other.stimuli[0].shape}" + ) + if no_empty and self.responses[0].shape != other.responses[0].shape: + raise ValueError( + "Cannot merge ResponseData instances with inconsistent response shapes." + f"Received response shapes of {self.responses[0].shape} and {other.responses[0].shape}" + ) + + self.stimuli.extend(other.stimuli) self.responses.extend(other.responses) + both_contexts = self.contexts and other.contexts + + if self.contexts == [] and other.contexts == []: + pass + elif both_contexts and self.contexts[0].shape == other.contexts[0].shape: + self.contexts.extend(other.contexts) + else: + raise ValueError( + "Cannot merge ResponseData instances with inconsistent context." + ) def tail(self, n: int) -> ResponseData: """ @@ -261,9 +379,10 @@ def tail(self, n: int) -> ResponseData: New dataset with last n trials """ new_data = ResponseData() - new_data.refs = self.refs[-n:] - new_data.comparisons = self.comparisons[-n:] + new_data.stimuli = self.stimuli[-n:] new_data.responses = self.responses[-n:] + if self.contexts is not None: + new_data.contexts = self.contexts[-n:] return new_data def copy(self) -> ResponseData: @@ -276,28 +395,31 @@ def copy(self) -> ResponseData: New dataset with copied data """ new_data = ResponseData() - new_data.refs = list(self.refs) - new_data.comparisons = list(self.comparisons) + new_data.stimuli = list(self.stimuli) new_data.responses = list(self.responses) + if self.contexts is not None: + new_data.contexts = list(self.contexts) return new_data class TrialBatch: """ - Container for a proposed batch of trials + Container for a proposed batch of trials. + Does NOT include context or responses. Attributes ---------- - stimuli : List[Tuple[Any, Any]] - Each trial is a (reference, comparison) tuple. + stimuli : List[Tuple[Any, ...]] + Each trial is a tuple of all presented stimuli (stim1, stim2, ...). + For OddityTask this is (reference, comparison) """ - def __init__(self, stimuli: list[tuple[Any, Any]]) -> None: + def __init__(self, stimuli: list[tuple[Any, ...]]) -> None: self.stimuli = list(stimuli) @classmethod - def from_stimuli(cls, pairs: list[tuple[Any, Any]]) -> TrialBatch: + def from_stimuli(cls, groups: list[tuple[Any, ...]]) -> TrialBatch: """ - Construct a TrialBatch from a list of stimuli (ref, comparison) pairs. + Construct a TrialBatch from a list of stimuli (stim1, stim2, ...) groups. """ - return cls(pairs) + return cls(groups) diff --git a/src/psyphy/data/io.py b/src/psyphy/data/io.py index 56a0225..5881c8f 100644 --- a/src/psyphy/data/io.py +++ b/src/psyphy/data/io.py @@ -31,7 +31,10 @@ def save_responses_csv(data: TrialData | ResponseData, path: PathLike) -> None: """ - Save ResponseData to a CSV file. + Save ResponseData to a CSV file. This is completely task agnostic. In the + current implementation, it will simply create as many stimulus columns as + inputs in the data and will label them "stimulus 1", "stimulus 2", etc. The + response column will be labeled "response" Parameters ---------- @@ -39,27 +42,34 @@ def save_responses_csv(data: TrialData | ResponseData, path: PathLike) -> None: path : str or Path """ if isinstance(data, TrialData): - refs, comparisons, resps = ( - np.asarray(data.refs), - np.asarray(data.comparisons), + inputs, resps = ( + np.asarray(data.stimuli), np.asarray(data.responses), ) else: - refs, comparisons, resps = data.to_numpy() + inputs, resps = data.to_numpy() + row_names = [] + for s in range(inputs.shape[1]): + row_names.append("stimulus " + str(s)) + row_names.append("response") with open(path, "w", newline="") as f: writer = csv.writer(f) - writer.writerow(["ref", "probe", "response"]) - for r, p, y in zip(refs, comparisons, resps): - writer.writerow([r.tolist(), p.tolist(), int(y)]) + writer.writerow(row_names) + for x, y in zip(inputs, resps): + row = x.tolist() + row.append(y.tolist()) + writer.writerow(row) def load_responses_csv(path: PathLike) -> TrialData: """ Load ResponseData from a CSV file. + Currently catering to OddityTask data format ONLY. Parameters ---------- path : str or Path + - must be of expected format for OddityTask Returns ------- @@ -72,7 +82,7 @@ def load_responses_csv(path: PathLike) -> TrialData: ref = ast.literal_eval(row["ref"]) probe = ast.literal_eval(row["probe"]) resp = int(row["response"]) - data.add_trial(ref, probe, resp) + data.add_trial((ref, probe), resp) return data.to_trial_data() diff --git a/src/psyphy/model/likelihood.py b/src/psyphy/model/likelihood.py index 7c92728..bc9a776 100644 --- a/src/psyphy/model/likelihood.py +++ b/src/psyphy/model/likelihood.py @@ -132,7 +132,7 @@ def loglik( params : Any Model parameters. data : Any - Object with ``.refs``, ``.comparisons``, ``.responses`` array attributes. + Object with ``.stimuli``, ``.responses`` array attributes. model : Any Model instance. key : jax.random.KeyArray, optional @@ -144,9 +144,11 @@ def loglik( jnp.ndarray Scalar sum of Bernoulli log-likelihoods over all trials. """ - refs = jnp.asarray(data.refs) - comparisons = jnp.asarray(data.comparisons) + stimuli = jnp.asarray(data.stimuli) + refs = stimuli[:, 0, :] + comparisons = stimuli[:, 1, :] responses = jnp.asarray(data.responses) + responses = responses.astype(int) n_trials = int(refs.shape[0]) base_key = key if key is not None else jr.PRNGKey(0) diff --git a/src/psyphy/trial_placement/__init__.py b/src/psyphy/trial_placement/__init__.py index 2bff8f0..a80984b 100644 --- a/src/psyphy/trial_placement/__init__.py +++ b/src/psyphy/trial_placement/__init__.py @@ -4,6 +4,8 @@ Non-acquisition-based trial placement strategies. +NOTE: this modeule is currently untested and still in active development. + This module provides classical (non-Bayesian-optimization) placement strategies: - GridPlacement: Fixed grid designs for systematic exploration diff --git a/src/psyphy/trial_placement/grid.py b/src/psyphy/trial_placement/grid.py index b43d581..4cda6c0 100644 --- a/src/psyphy/trial_placement/grid.py +++ b/src/psyphy/trial_placement/grid.py @@ -4,9 +4,6 @@ Grid-based placement strategy. -MVP: -- Iterates through a fixed list of grid points. -- Ignores the posterior (non-adaptive). Full WPPM mode: - Could refine the grid adaptively around regions of high posterior uncertainty. @@ -26,8 +23,11 @@ class GridPlacement: Notes ----- - - grid = your set of allowable trials. - - this class simply walks through that set. + - grid = your set of allowable trials; this class simply walks through that set. + - Not yet tested. Pending the same ``TrialBatch`` redesign as ``SobolPlacement``: + each trial's stimuli should be a single ``np.ndarray`` of shape ``(K, d)`` + rather than a two-element tuple, to align with the generalised ``TrialData`` + layout introduced in the data-object refactor. """ def __init__(self, grid_points): @@ -41,7 +41,6 @@ def propose(self, posterior, batch_size: int) -> TrialBatch: Parameters ---------- posterior : Posterior - Ignored in MVP (grid is non-adaptive). batch_size : int Number of trials to return. diff --git a/src/psyphy/trial_placement/sobol.py b/src/psyphy/trial_placement/sobol.py index 58d1640..751b688 100644 --- a/src/psyphy/trial_placement/sobol.py +++ b/src/psyphy/trial_placement/sobol.py @@ -29,6 +29,20 @@ class SobolPlacement: Bounds per dimension. seed : int, optional RNG seed. + + Notes + ----- + Not yet tested. Pending two design changes tracked in the trial-placement + follow-up issue: + + 1. ``TrialBatch`` currently stores stimuli as ``list[tuple[Any, Any]]`` + (two-stimulus tuples). It should be updated to ``list[np.ndarray]`` each + of shape ``(K, d)`` to align with ``TrialData.stimuli`` and allow + ``ResponseData.add_batch`` to consume it without conversion. + + 2. The zero reference vector is hardcoded here. It should be an explicit + parameter so the caller controls which point in stimulus space acts as + the reference, rather than always using the origin. """ def __init__(self, dim: int, bounds, seed: int = 0): diff --git a/tests/test_3stimulus_decision_rule.py b/tests/test_3stimulus_decision_rule.py index 87aedde..f4e756c 100644 --- a/tests/test_3stimulus_decision_rule.py +++ b/tests/test_3stimulus_decision_rule.py @@ -39,7 +39,7 @@ def test_identical_stimuli_near_chance(self): data = ResponseData() ref = jnp.array([0.5, 0.5]) comparison = ref # identical! - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) # Compute P(correct) with many samples for accurate estimate # loglik = log P(correct | ref, comparison, params) @@ -82,7 +82,7 @@ def test_distant_stimuli_near_perfect(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([5.0, 5.0]) # Very far! - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) # Compute P(correct) loglik = model.likelihood.loglik( @@ -121,7 +121,7 @@ def test_intermediate_stimuli_between_chance_and_perfect(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([0.3, 0.3]) # Moderately different - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) # Compute P(correct) loglik = model.likelihood.loglik(params, data, model, key=jr.PRNGKey(42)) @@ -155,7 +155,7 @@ def test_wishart_mode_same_behavior(self): # Test with identical stimuli data_identical = ResponseData() ref = jnp.array([0.5, 0.5]) - data_identical.add_trial(ref, ref, resp=1) + data_identical.add_trial((ref, ref), resp=1) model.likelihood = OddityTask( config=OddityTaskConfig(num_samples=3000, bandwidth=1e-2) @@ -171,7 +171,7 @@ def test_wishart_mode_same_behavior(self): # Test with distant stimuli data_distant = ResponseData() comparison_far = jnp.array([5.0, 5.0]) - data_distant.add_trial(ref, comparison_far, resp=1) + data_distant.add_trial((ref, comparison_far), resp=1) loglik_distant = model.likelihood.loglik( params, @@ -304,7 +304,7 @@ def test_very_small_bandwidth(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([2.0, 2.0]) - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) loglik = model.likelihood.loglik(params, data, model, key=jr.PRNGKey(42)) @@ -336,7 +336,7 @@ def test_small_num_samples(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([3.0, 3.0]) - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) loglik = model.likelihood.loglik(params, data, model, key=jr.PRNGKey(42)) @@ -366,7 +366,7 @@ def test_large_num_samples_convergence(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([3.0, 3.0]) - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) # Compute with two different seeds (same task config) loglik1 = model.likelihood.loglik(params, data, model, key=jr.PRNGKey(42)) @@ -402,13 +402,13 @@ def test_multiple_trials_accumulates_loglik(self): ) params = model.init_params(jr.PRNGKey(0)) - # Create multiple trials - data_multi = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([1.0, 1.0]) + # Create multiple trials + data_multi = ResponseData() for _ in range(5): - data_multi.add_trial(ref, comparison, resp=1) + data_multi.add_trial((ref, comparison), resp=1) # Compute combined loglik loglik_multi = model.likelihood.loglik( @@ -420,7 +420,7 @@ def test_multiple_trials_accumulates_loglik(self): # Single trial loglik data_single = ResponseData() - data_single.add_trial(ref, comparison, resp=1) + data_single.add_trial((ref, comparison), resp=1) loglik_single = model.likelihood.loglik( params, @@ -433,8 +433,14 @@ def test_multiple_trials_accumulates_loglik(self): # (with some tolerance for MC variance) expected_loglik = 5 * loglik_single - assert ( - jnp.abs(loglik_multi - expected_loglik) / jnp.abs(expected_loglik) < 0.1 + # Use combined absolute +relative tolerance (np.allclose pattern) so the + # assertion is robust when expected_loglik is near zero, as + # pure relative + # tolerance blows up in that regime due to MC variance across key splits. + atol = 0.05 + rtol = 0.1 + assert jnp.abs(loglik_multi - expected_loglik) <= atol + rtol * jnp.abs( + expected_loglik ), ( f"Multi-trial loglik = {loglik_multi:.3f}, " f"expected ≈ {expected_loglik:.3f} (5 × single trial). " @@ -443,7 +449,8 @@ def test_multiple_trials_accumulates_loglik(self): def test_zero_samples_raises_error(self): """Test that num_samples=0 raises an error.""" - # Strict API: num_samples comes from task config, so invalid config should fail at construction. + # Strict API: num_samples comes from task config, so invalid + # config should fail at construction. with pytest.raises(ValueError, match="num_samples must be > 0"): _ = OddityTask(config=OddityTaskConfig(num_samples=0, bandwidth=1e-2)) @@ -463,7 +470,7 @@ def test_reproducibility_with_same_seed(self): params = model.init_params(jr.PRNGKey(0)) data = ResponseData() - data.add_trial(jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0]), resp=1) + data.add_trial((jnp.array([0.0, 0.0]), jnp.array([1.0, 1.0])), resp=1) # Compute twice with same seed loglik1 = model.likelihood.loglik( @@ -510,7 +517,7 @@ def test_decision_rule_symmetric_in_reference_samples(self): data = ResponseData() ref = jnp.array([0.0, 0.0]) comparison = jnp.array([1.0, 1.0]) - data.add_trial(ref, comparison, resp=1) + data.add_trial((ref, comparison), resp=1) # Compute with one seed loglik1 = model.likelihood.loglik( diff --git a/tests/test_covariance_field.py b/tests/test_covariance_field.py index e0e57fc..2419b81 100644 --- a/tests/test_covariance_field.py +++ b/tests/test_covariance_field.py @@ -59,7 +59,7 @@ def sample_data(): key, subkey = jr.split(key) probe = ref + 0.1 * jr.normal(subkey, shape=(2,)) response = 1 if i % 2 == 0 else 0 - data.add_trial(ref, probe, response) + data.add_trial((ref, probe), response) return data diff --git a/tests/test_data_format.py b/tests/test_data_format.py new file mode 100644 index 0000000..b4cc3f3 --- /dev/null +++ b/tests/test_data_format.py @@ -0,0 +1,285 @@ +import jax.numpy as jnp +import numpy as np +import pytest + +from psyphy.data.dataset import ResponseData, TrialData + + +class TestResponses: + """Test that all data structures accept appropriate continuous & > 1D responses""" + + def test_cts_responses(self): + """1D continuous response values be accepted and manipulated appropriately.""" + + # Create ResponseData + data = ResponseData() + stimuli = ([0.5, 0.5], [1, 0]) + + # Add trials with non-binary responses: + data.add_trial(input=stimuli, resp=1.0) + data.add_trial(input=stimuli, resp=0.3) + data.add_trial(input=stimuli, resp=0.82) + + # Check to ensure responses are not stored as binary: + nb_resp = data.responses[0] + assert nb_resp is not int, ( + "continuous response was incorrectly forced to type int." + ) + + # Convert to TrialData + td_data = data.to_trial_data() + + # Check that TrialData still stores continuous responses: + nb_resp = td_data.responses[1] + assert nb_resp == 0.3, ( + "continuous response was not correctly preserved during ResponseData \ + --> TrialData conversion" + ) + + # Convert back to ResponseData & recheck + r_data = ResponseData.from_trial_data(td_data) + assert r_data.responses[2] == 0.82, ( + "continuous response was not correctly preserved during TrialData \ + --> ResponseData re-conversion" + ) + + def test_response_dimensions(self): + """Test that responses can be > 1D""" + + # Create Response Data: + stimuli = (jnp.array([0.5, 0.5]), jnp.array([1, 1])) + + # Create many different types of responses to add: + responses = [ + 1, + 0.5, + jnp.array([0.5, 0.5]), + np.array([1, 23, 0.1]), + [1, 3], + [1, 0.2, 0.9], + ] + + # Add & test responses as both ResponseData and TrialData: + for resp in responses: + data = ResponseData() + data.add_trial(stimuli, resp) + assert (data.responses[0] == np.asarray(resp)).all(), ( + f"The response value of {resp} was incorrectly saved as \ + {data.responses} in ResponseData" + ) + td_data = data.to_trial_data() + assert (td_data.responses == jnp.asarray(resp)).all(), ( + f"The response value of {resp} was incorrectly saved as\ + {td_data.responses} in TrialData" + ) + + +class TestInputs: + """Test that all data structures correctly handle general response types. + + Should be able to handle stimuli of arbitrary dimension and arbitrary K + (number of stimuli per trial). + + Should NOT accept stimuli with altered number of + stimuli or altered stimulus dimensions once precedent has already been + established for a given ResponseData instance. + """ + + def test_num_stim(self): + """Test all data structures can handle when there are not exactly two stimuli""" + response = 1 + stimulus_1 = (jnp.array([1, 1]),) # K=1: wrap in tuple so len() gives K, not d + stimulus_2 = (jnp.array([0, 0.5]), jnp.array([1, 1])) + stimulus_3 = (jnp.array([0, 0.5]), jnp.array([1, 1]), jnp.array([0, 6])) + stimuli = [stimulus_1, stimulus_2, stimulus_3] + for i, stim in enumerate(stimuli): + data = ResponseData() + data.add_trial(stim, response) + data.add_trial(stim, response) + with pytest.raises(ValueError): + data.add_trial(stimuli[abs(i - 1)], response) + assert len(stim) == data.stim_shape[0], ( + f"ResponseData was given {len(stim)} stimuli, but represented {data.stim_shape[0]}." + ) + td_data = data.to_trial_data() + assert len(stim) == td_data.stimuli.shape[1], ( + f"TrialData was given {len(stim)} stimuli, but represented {data.stim_shape[0]}." + ) + + def test_stim_shape_none_before_first_trial(self): + """stim_shape must be None on a fresh ResponseData, not raise AttributeError.""" + data = ResponseData() + assert data.stim_shape is None + + def test_from_arrays_2d_input(self): + """from_arrays must accept 2D X (n_trials, d) and treat it as K=1.""" + X = np.array([[0.1, 0.2], [0.3, 0.4]]) # shape (2, d) — no K axis + y = np.array([[1], [0]]) + data = ResponseData.from_arrays(X, y) + assert len(data) == 2 + assert data.stim_shape == (1, 2) # K=1, d=2 + + def test_stim_dem(self): + """Test that all data structures correctly handle stimuli that are not 2D""" + response = 1 + stim_1 = (jnp.array([0]), jnp.array([1])) + stim_2 = (jnp.array([0, 0.5]), jnp.array([1, 1])) + stim_3 = (jnp.array([0, 1, 0.3]), jnp.array([0.4, 0.4, 0.4])) + stimuli = [stim_1, stim_2, stim_3] + for i, stim in enumerate(stimuli): + data = ResponseData() + data.add_trial(stim, response) + data.add_trial(stim, response) + with pytest.raises(ValueError): + data.add_trial(stimuli[abs(i - 1)], response) + dim = stim[0].shape[0] + assert dim == data.stim_shape[1], ( + f"ResponseData was given {dim}-dim stimuli, but represented \ + {data.stim_shape[1]} dimensions." + ) + td_data = data.to_trial_data() + assert dim == td_data.stimuli.shape[2], ( + f"TrialData was given {dim}-dim stimuli, but represented \ + {data.stim_shape[0]} dimensions." + ) + + +class TestShapeValidation: + """TrialData must reject arrays that are not exactly (N, K, d) and (N, R).""" + + def test_4d_stimuli_raises(self): + """A 4D stimuli array must be rejected — previously slipped through > 3 check.""" + stimuli_4d = jnp.ones((5, 2, 3, 4)) # one extra axis + responses = jnp.ones((5, 1)) + with pytest.raises(ValueError, match="stimuli must be 3D"): + TrialData(stimuli=stimuli_4d, responses=responses) + + def test_2d_stimuli_raises(self): + """A 2D stimuli array (missing K axis) must be rejected.""" + stimuli_2d = jnp.ones((5, 3)) + responses = jnp.ones((5, 1)) + with pytest.raises(ValueError, match="stimuli must be 3D"): + TrialData(stimuli=stimuli_2d, responses=responses) + + def test_3d_responses_raises(self): + """A 3D responses array must be rejected — previously slipped through > 2 check.""" + stimuli = jnp.ones((5, 2, 3)) + responses_3d = jnp.ones((5, 1, 1)) + with pytest.raises(ValueError, match="responses must be 2D"): + TrialData(stimuli=stimuli, responses=responses_3d) + + def test_valid_shapes_accepted(self): + """Correct (N, K, d) / (N, R) shapes must not raise.""" + TrialData(stimuli=jnp.ones((5, 2, 3)), responses=jnp.ones((5, 1))) + + +class TestContext: + """Ensure that the optional attribute context behaves as expected for all + data structures. + """ + + def test_merge_with_1d_context(self): + """merge() must not crash when context is 1D (scalar per trial). + Previously raised IndexError via .shape[1] on a shape-() array.""" + stim = ([0.5, 0.5], [1.0, 0.0]) + a = ResponseData() + a.add_trial(stim, 1, context=0.3) + b = ResponseData() + b.add_trial(stim, 0, context=0.7) + a.merge(b) + assert len(a) == 2 + + def test_add_1Dcontext_to_ResponseData(self): + """Add 1D contexts""" + data = ResponseData() + stimuli = ([0.5, 0.5], [1, 0]) + response = 1 + c1 = 2 + c2 = 9.2 + data.add_trial(input=stimuli, resp=response, context=c1) + assert data.contexts == [c1] + data.add_trial(input=stimuli, resp=response, context=c2) + assert data.contexts == [c1, c2] + + def test_add_nDcontext_to_ResponseData(self): + """Add n-dimensional contexts""" + data = ResponseData() + stimuli = ([0.5, 0.5], [1, 0]) + response = 1 + c1 = [0, 2, 0.1] + c2 = [3, 1, 2] + data.add_trial(input=stimuli, resp=response, context=c1) + assert (data.contexts[0] == np.asarray(c1)).all() + data.add_trial(input=stimuli, resp=response, context=c2) + for i, c in enumerate([c1, c2]): + assert (data.contexts[i] == c).all() + + def test_convert_data_with_context(self): + """test conversions between ResponseData and TrialData for instances + with context.""" + + # Create ResponseData instance + data = ResponseData() + + # Add trial with context to ResponseData + stimuli = ([0.5, 0.5], [1, 0]) + response = 1 + context = [0, 2, 0.1] + data.add_trial(input=stimuli, resp=response, context=context) + + # Convert to TrialData and check context + td_data = data.to_trial_data() + np.testing.assert_array_equal(td_data.context, jnp.asarray([context])) + + # Convert back to ResponseData and check context + r_data = ResponseData.from_trial_data(td_data) + np.testing.assert_allclose(r_data.contexts[0], context) + + +class TestStimulusNames: + """TrialData.stimulus_names and the .stimulus() named accessor.""" + + _stimuli = jnp.array( + [[[0.3, -0.5], [0.4, -0.4]], [[0.1, 0.2], [0.5, 0.6]]] + ) # (N=2, K=2, d=2) + _responses = jnp.array([[1], [0]]) # (N=2, R=1) + + def test_stimulus_names_default_empty(self): + """stimulus_names defaults to an empty tuple when not provided.""" + data = TrialData(stimuli=self._stimuli, responses=self._responses) + assert data.stimulus_names == () + + def test_stimulus_names_stored(self): + """stimulus_names is stored and retrievable.""" + data = TrialData( + stimuli=self._stimuli, + responses=self._responses, + stimulus_names=("ref", "comp"), + ) + assert data.stimulus_names == ("ref", "comp") + + def test_stimulus_accessor_returns_correct_slice(self): + """data.stimulus('ref') returns stimuli[:, 0, :] for slot 0.""" + data = TrialData( + stimuli=self._stimuli, + responses=self._responses, + stimulus_names=("ref", "comp"), + ) + np.testing.assert_array_equal(data.stimulus("ref"), self._stimuli[:, 0, :]) + np.testing.assert_array_equal(data.stimulus("comp"), self._stimuli[:, 1, :]) + + def test_stimulus_accessor_unknown_name_raises(self): + """data.stimulus() with an unrecognised name raises ValueError.""" + data = TrialData( + stimuli=self._stimuli, + responses=self._responses, + stimulus_names=("ref", "comp"), + ) + with pytest.raises(ValueError, match="unknown stimulus name"): + data.stimulus("nonexistent") + + def test_stimulus_accessor_empty_names_raises(self): + """data.stimulus() raises if stimulus_names was not set.""" + data = TrialData(stimuli=self._stimuli, responses=self._responses) + with pytest.raises(ValueError, match="stimulus_names"): + data.stimulus("ref") diff --git a/tests/test_imports.py b/tests/test_imports.py index 788369b..72cfa1c 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -14,5 +14,6 @@ def test_top_level_api_imports(): "MAPPosterior", "ResponseData", "TrialBatch", + "TrialData", ]: assert hasattr(p, name) diff --git a/tests/test_mc_likelihood.py b/tests/test_mc_likelihood.py index 1436b2e..0089091 100644 --- a/tests/test_mc_likelihood.py +++ b/tests/test_mc_likelihood.py @@ -54,9 +54,12 @@ def test_mc_likelihood_method_exists(self, model): def test_mc_likelihood_shape_and_dtype(self, model, simple_params): """Test MC likelihood returns scalar with correct dtype.""" + + refs = jnp.array([[0.0, 0.0]]) + comparisons = jnp.array([[0.1, 0.1]]) + data = TrialData( - refs=jnp.array([[0.0, 0.0]]), - comparisons=jnp.array([[0.1, 0.1]]), + stimuli=jnp.stack([refs, comparisons], axis=1), responses=jnp.array([1], dtype=jnp.int32), ) @@ -86,8 +89,7 @@ def test_mc_likelihood_convergence_to_analytical(self, model, simple_params): ref = jnp.array([0.0, 0.0]) comparison = jnp.array([0.5, 0.5]) # Far enough for clear discrimination data = TrialData( - refs=jnp.array([ref]), - comparisons=jnp.array([comparison]), + stimuli=jnp.stack([jnp.array([ref]), jnp.array([comparison])], axis=1), responses=jnp.array([1], dtype=jnp.int32), ) @@ -137,9 +139,11 @@ def test_mc_likelihood_increases_num_samples(self, model, simple_params): noise=model.noise, ) data = ResponseData() + refs = jnp.array([[0.0, 0.0]]) + comparisons = jnp.array([[0.2, 0.1]]) + data = TrialData( - refs=jnp.array([[0.0, 0.0]]), - comparisons=jnp.array([[0.2, 0.1]]), + stimuli=jnp.stack([refs, comparisons], axis=1), responses=jnp.array([1], dtype=jnp.int32), ) @@ -176,9 +180,11 @@ def test_mc_likelihood_increases_num_samples(self, model, simple_params): def test_mc_likelihood_batch_correctness(self, model, simple_params): """MC likelihood should handle multiple trials correctly.""" + refs = jnp.array([[0.0, 0.0], [0.5, 0.5], [-0.3, 0.2]]) + comparisons = jnp.array([[0.1, 0.1], [0.6, 0.4], [-0.2, 0.3]]) + data = TrialData( - refs=jnp.array([[0.0, 0.0], [0.5, 0.5], [-0.3, 0.2]]), - comparisons=jnp.array([[0.1, 0.1], [0.6, 0.4], [-0.2, 0.3]]), + stimuli=jnp.stack([refs, comparisons], axis=1), responses=jnp.array([1, 0, 1], dtype=jnp.int32), ) @@ -202,9 +208,10 @@ def test_mc_likelihood_bandwidth_sensitivity(self, model, simple_params, bandwid Smaller bandwidth -> sharper transitions (closer to step function). """ data = ResponseData() + refs = jnp.array([[0.0, 0.0]]) + comparisons = jnp.array([[0.1, 0.1]]) data = TrialData( - refs=jnp.array([[0.0, 0.0]]), - comparisons=jnp.array([[0.1, 0.1]]), + stimuli=jnp.stack([refs, comparisons], axis=1), responses=jnp.array([1], dtype=jnp.int32), ) @@ -237,9 +244,7 @@ def test_mc_likelihood_reproducibility(self, model, simple_params): noise=model.noise, ) data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.1, 0.1]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.1, 0.1])), resp=1) ll_1 = model.likelihood.loglik( params=simple_params, @@ -268,9 +273,7 @@ def test_mc_likelihood_different_seeds_vary(self, model, simple_params): noise=model.noise, ) data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.1, 0.1]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.1, 0.1])), resp=1) ll_1 = model.likelihood.loglik( params=simple_params, @@ -327,8 +330,7 @@ def test_mc_likelihood_wishart_mode(self): # Create oddity trial data data = ResponseData() data.add_trial( - ref=jnp.array([0.0, 0.0]), - comparison=jnp.array([0.5, 0.0]), + input=(jnp.array([0.0, 0.0]), jnp.array([0.5, 0.0])), resp=1, ) @@ -356,7 +358,7 @@ def test_mc_likelihood_identical_stimuli(self, model): data = ResponseData() # Identical stimuli stim = jnp.array([0.5, 0.5]) - data.add_trial(ref=stim, comparison=stim, resp=1) + data.add_trial(input=(stim, stim), resp=1) model = WPPM( input_dim=model.input_dim, @@ -383,9 +385,7 @@ def test_mc_likelihood_extreme_discriminability(self, model): params = model.init_params(jr.PRNGKey(0)) data = ResponseData() # Very far apart - data.add_trial( - ref=jnp.array([-0.9, -0.9]), comparison=jnp.array([0.9, 0.9]), resp=1 - ) + data.add_trial(input=(jnp.array([-0.9, -0.9]), jnp.array([0.9, 0.9])), resp=1) model = WPPM( input_dim=model.input_dim, @@ -454,9 +454,7 @@ def test_gradients_are_finite_normal_case(self, model): # Create simple dataset: one trial with moderate discriminability data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.5, 0.5]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.5, 0.5])), resp=1) # Define loss function (negative log-likelihood) def loss_fn(p): @@ -501,9 +499,7 @@ def test_gradients_finite_with_identical_stimuli(self, model): data = ResponseData() # Identical stimuli - observer is just guessing - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.0, 0.0]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.0, 0.0])), resp=1) def loss_fn(p): return -model.likelihood.loglik( @@ -534,9 +530,7 @@ def test_gradients_finite_with_extreme_discriminability(self, model): data = ResponseData() # Very far apart stimuli - near perfect discrimination - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([5.0, 5.0]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([5.0, 5.0])), resp=1) def loss_fn(p): return -model.likelihood.loglik( @@ -574,9 +568,7 @@ def test_gradients_finite_with_small_bandwidth(self, model): params = model.init_params(jr.PRNGKey(0)) data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.3]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.3, 0.3])), resp=1) def loss_fn(p): return -model.likelihood.loglik( @@ -628,9 +620,7 @@ def test_log_likelihood_finite_at_chance(self, model): data = ResponseData() # Identical stimuli -> chance performance - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.0, 0.0]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.0, 0.0])), resp=1) ll = model.likelihood.loglik( params=params, @@ -661,9 +651,7 @@ def test_log_likelihood_finite_at_perfect(self, model): data = ResponseData() # Very far apart -> near perfect discrimination - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([10.0, 10.0]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([10.0, 10.0])), resp=1) ll = model.likelihood.loglik( params=params, @@ -701,7 +689,7 @@ def test_probability_never_exactly_zero_or_one(self, model): for ref, comp, description in test_cases: data = ResponseData() - data.add_trial(ref=ref, comparison=comp, resp=1) + data.add_trial(input=(ref, comp), resp=1) ll = model.likelihood.loglik( params=params, @@ -750,8 +738,7 @@ def test_stability_with_very_small_noise(self): data = ResponseData() data.add_trial( - ref=jnp.array([0.0, 0.0, 0.0]), - comparison=jnp.array([0.1, 0.1, 0.1]), + input=(jnp.array([0.0, 0.0, 0.0]), jnp.array([0.1, 0.1, 0.1])), resp=1, ) @@ -788,8 +775,7 @@ def test_stability_with_wishart_mode(self): data = ResponseData() data.add_trial( - ref=jnp.array([0.0, 0.0, 0.0]), - comparison=jnp.array([0.5, 0.5, 0.5]), + input=(jnp.array([0.0, 0.0, 0.0]), jnp.array([0.5, 0.5, 0.5])), resp=1, ) @@ -838,9 +824,7 @@ def test_variance_decreases_with_more_samples(self, model): data = ResponseData() # Use intermediate discriminability for variance (not too easy, not too hard) - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.3, 0.3]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.3, 0.3])), resp=1) # Test with small and large sample sizes sample_sizes = [100, 1600] # 16x difference diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 46d0b78..1770ad6 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -48,7 +48,7 @@ def data_arrays(self): # Create ResponseData object data = ResponseData() for i in range(n): - data.add_trial(refs[i], comparisons[i], int(y[i])) + data.add_trial((refs[i], comparisons[i]), int(y[i])) return data def test_optimizer_fit(self, model, data_arrays): @@ -89,7 +89,7 @@ def posterior(self): data = ResponseData() for i in range(n): - data.add_trial(refs[i], comparisons[i], int(y[i])) + data.add_trial((refs[i], comparisons[i]), int(y[i])) optimizer = MAPOptimizer(steps=20) return optimizer.fit(model, data) @@ -158,7 +158,7 @@ def initial_data(self): data = ResponseData() for i in range(n): - data.add_trial(refs[i], comparisons[i], int(y[i])) + data.add_trial((refs[i], comparisons[i]), int(y[i])) return data def test_initial_fit(self, model, initial_data): @@ -191,7 +191,7 @@ def test_full_new_api_workflow(self): data = ResponseData() for i in range(n): - data.add_trial(refs[i], comparisons[i], int(y[i])) + data.add_trial((refs[i], comparisons[i]), int(y[i])) # 3. Fit model (new API) optimizer = MAPOptimizer(steps=50) @@ -234,7 +234,7 @@ def test_manual_online_loop(self): key, subkey = jr.split(key) comp = ref + jr.normal(subkey, (2,)) * 0.1 - data.add_trial(ref, comp, 1) + data.add_trial((ref, comp), 1) # Re-fit every step (naive online learning) posterior = optimizer.fit(model, data) diff --git a/tests/test_noise_models.py b/tests/test_noise_models.py index c137037..e6c6a36 100644 --- a/tests/test_noise_models.py +++ b/tests/test_noise_models.py @@ -56,9 +56,7 @@ def test_noise_models_give_different_results(self, model_gaussian, model_student # Create a trial data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.1, 0.1]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.1, 0.1])), resp=1) # Compute likelihood with Gaussian noise ll_gaussian = model_gaussian.likelihood.loglik( @@ -88,9 +86,7 @@ def test_student_t_heavy_tails(self, model_student_t): """ params = model_student_t.init_params(jr.PRNGKey(0)) data = ResponseData() - data.add_trial( - ref=jnp.array([0.0, 0.0]), comparison=jnp.array([0.5, 0.5]), resp=1 - ) + data.add_trial(input=(jnp.array([0.0, 0.0]), jnp.array([0.5, 0.5])), resp=1) # Override MC fidelity for this test via task config. model_student_t = WPPM( diff --git a/tests/test_posteriors.py b/tests/test_posteriors.py index 89560fe..8595579 100644 --- a/tests/test_posteriors.py +++ b/tests/test_posteriors.py @@ -42,7 +42,9 @@ def data(self): refs = jnp.array([[0.0, 0.0], [1.0, 1.0]]) comparisons = jnp.array([[0.5, 0.5], [1.5, 1.0]]) responses = jnp.array([1, 0], dtype=jnp.int32) - return TrialData(refs=refs, comparisons=comparisons, responses=responses) + return TrialData( + stimuli=jnp.stack([refs, comparisons], axis=1), responses=responses + ) @pytest.fixture def param_posterior(self, model, data): @@ -103,7 +105,9 @@ def data(self): refs = jr.normal(jr.PRNGKey(0), (10, 2)) comparisons = refs + jr.normal(jr.PRNGKey(1), (10, 2)) * 0.3 responses = jnp.ones((10,), dtype=jnp.int32) - return TrialData(refs=refs, comparisons=comparisons, responses=responses) + return TrialData( + stimuli=jnp.stack([refs, comparisons], axis=1), responses=responses + ) @pytest.fixture def param_posterior(self, model, data): @@ -206,7 +210,9 @@ def test_full_workflow(self): refs = jr.normal(k_ref, (20, 2)) comparisons = refs + jr.normal(k_eps, (20, 2)) * 0.5 responses = jnp.ones((20,), dtype=jnp.int32) - data = TrialData(refs=refs, comparisons=comparisons, responses=responses) + data = TrialData( + stimuli=jnp.stack([refs, comparisons], axis=1), responses=responses + ) # 3. Fit model -> ParameterPosterior optimizer = MAPOptimizer(steps=50) diff --git a/tests/test_wishart_covariance.py b/tests/test_wishart_covariance.py index f93e705..787184e 100644 --- a/tests/test_wishart_covariance.py +++ b/tests/test_wishart_covariance.py @@ -259,7 +259,8 @@ def test_fit_with_wishart(self): # X = jnp.stack([refs, comparisons], axis=1) responses = jnp.ones((n,), dtype=jnp.int32) - data = TrialData(refs=refs, comparisons=comparisons, responses=responses) + stimuli = jnp.stack([refs, comparisons], axis=1) + data = TrialData(stimuli=stimuli, responses=responses) # fit optimizer = MAPOptimizer(steps=10)