Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/API_DESIGN_DOCUMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ src/psyphy/
│ └── expected_improvement.py
├── data/ # Data handling
│ ├── dataset.py # ResponseData class
│ ├── dataset.py # ResponseData class, TrialData class
│ └── io.py # Loading/saving
└── utils/ # Utilities
Expand Down
10 changes: 6 additions & 4 deletions docs/examples/wppm/full_wppm_fit_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,13 @@ Note on data used in this script: here, we simulate data (and hence have a groun
- The canonical, (batched) input expected by likelihood evaluation and optimizers (e.g., MAPOptimizer.fit(...)).
It holds JAX arrays:

- refs: (N, d)
- comparisons: (N, d)
- responses: (N,)
- stimuli: (N, K, d)
- responses: (N, R)
- context: (N, C) (optional)

where $d$ refers to the input dimension, here 2.
where $N$ is the number of trials, $K$ is the number of distinct stimuli per trial (e.g. K=2 in the Oddity task: reference and comparison), $d$ is the stimulus space dimensionality, $R$ is the number of response channels, and $C$ is the number of context channels.

Context is an optional attribute which is intended to track task-wide features that vary across trials which might condition likelihoods. For example, psyiological metrics like pupil dilation could be included as context. PsyPhy does not yet have inbuilt uses for context.

**`ResponseData` (collection/I/O-first; convenient for experiments):**
- A Python-friendly log (stores trials in lists) designed for incremental collection, saving/loading (e.g., CSV), and adaptive experiments but expensive for computation
Expand Down
15 changes: 9 additions & 6 deletions docs/examples/wppm/full_wppm_fit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _ellipse_segments_from_covs(
eigvals = jnp.linalg.eigvalsh(covs)
valid = jnp.all(eigvals > 0, axis=-1)

# Cholesky is faster than eigendecomposition for SPD matrices.
# Cholesky is faster than eigen decomposition for SPD matrices.
# we only use it for plotting ellipses (shape/orientation), not inference.
def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
L = jnp.linalg.cholesky(cov)
Expand Down Expand Up @@ -257,13 +257,16 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
# Build the canonical batched dataset for compute.
#
# Notes:
# - This is equivalent to storing X with shape (N, 2, d) and y with shape (N,)
# where X[:,0,:]=refs and X[:,1,:]=comparisons.
# - We keep named fields because it's currently native to OddityTask.
# - Even though oddity is a 3-item task, we only store (ref, comparison)
# - This is equivalent to storing X with shape (N, K, d) and y with shape (N, R)
# where N is trials, K is number of distinct stimuli per trial, d is stimulus
# dimensionality, and R is number of response channels.
# - For OddityTask, X[:,0,:]=refs and X[:,1,:]=comparisons.
# - Note that even though oddity is a 3-item task, we only store (ref, comparison)
# because the oddity trial is assumed to be (ref, ref, comparison)
#
# --8<-- [start:data]
data = TrialData(refs=refs, comparisons=comparisons, responses=ys)
stimuli = jnp.stack([refs, comparisons], axis=1)
data = TrialData(stimuli=stimuli, responses=ys)
# --8<-- [end:data]

# --8<-- [end:simulate_data]
Expand Down
Binary file modified docs/examples/wppm/plots/quick_start_ellipses.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/examples/wppm/plots/quick_start_learning_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion docs/examples/wppm/quick_start.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,5 +129,5 @@ initialization (blue), and the MAP fit (red) at the single reference point:
## Next steps

- **Spatially-varying field:** scale up to a full 2-D grid → [full example](full_wppm_fit_example.md).
- **Your own data:** replace the simulated `TrialData` with your own `refs`, `comparisons`, and `responses` arrays.
- **Your own data:** replace the simulated `TrialData` with your own `inputs` and `responses` arrays. (currently calculations are only supported when the inputs are "refs" and "comparisons.")
- **API reference:** see [`MAPOptimizer`](../../reference/inference.md), [`WPPM`](../../reference/model.md), and [`WPPMCovarianceField`](../../reference/model.md).
5 changes: 2 additions & 3 deletions docs/examples/wppm/quick_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,8 @@ def _cov_to_points(cov: jnp.ndarray, center: jnp.ndarray) -> jnp.ndarray:
# --8<-- [end:simulate_data]

# --8<-- [start:data]
data = TrialData(
refs=refs, comparisons=comparisons, responses=ys
) # contains 3 JAX arrays
stimuli = jnp.stack([refs, comparisons], axis=1)
data = TrialData(stimuli=stimuli, responses=ys) # contains 2 JAX arrays
# --8<-- [end:data]


Expand Down
3 changes: 2 additions & 1 deletion src/psyphy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
# from . import session as session
from . import trial_placement as trial_placement
from . import utils as utils
from .data.dataset import ResponseData, TrialBatch
from .data.dataset import ResponseData, TrialBatch, TrialData
from .inference.langevin import LangevinSampler
from .inference.laplace import LaplaceApproximation

Expand Down Expand Up @@ -116,6 +116,7 @@
# Data handling
"ResponseData",
"TrialBatch",
"TrialData",
# Subpackages
"model",
"inference",
Expand Down
Loading
Loading