diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a082a45 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,57 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint-and-test: + name: lint + test (py${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: Install package + dev tools + run: | + python -m pip install --upgrade pip + pip install -e . + pip install ruff mypy pyright pytest + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + - name: mypy + run: mypy src/scperteval + - name: pyright + run: pyright src/scperteval + - name: pytest + run: pytest -q + + docs: + name: docs build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install package + docs deps + run: | + python -m pip install --upgrade pip + pip install -e . + pip install --group docs + - name: Build HTML docs + run: sphinx-build -b html -n docs docs/_build/html diff --git a/.gitignore b/.gitignore index 2fd5a41..b84bb62 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,11 @@ dist/ .venv/ venv/ +# Environment +uv.lock +.envrc +requirements-local.txt + # Run outputs results/ @@ -14,3 +19,7 @@ results/ .idea/ .vscode/ .DS_Store + +# docs +docs/generated/ +docs/_build/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2e7c9e3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +# Run `pre-commit install` once; hooks then run on every commit. +# Update pinned revs with `pre-commit autoupdate`. +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.13 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: check-added-large-files + args: [--maxkb=1024] diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..fe7f5fe --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,16 @@ +# https://docs.readthedocs.io/en/stable/config-file/v2.html +version: 2 +build: + os: ubuntu-24.04 + tools: + python: "3.12" + jobs: + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + build: + html: + - uv sync --group docs + - uv run sphinx-build -M html docs docs/_build -W + - mv docs/_build $READTHEDOCS_OUTPUT diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e2e763c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog + +## 0.1.0 (unreleased) + +Initial implementation of scPertEval. diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 32fbed5..37ab212 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -7,9 +7,9 @@ welcome. There are two paths, depending on what you're changing. If you're adding a protocol (a new metric, or a new combination of an existing metric with a space / centering / controls), **open a PR directly.** This is the common case and the -whole point of the project. See [Create a protocol](README.md#create-a-protocol) for the -two-step pattern (a pure function in `scperteval/protocols/algorithms.py` plus a row in -`scperteval/protocols/table.py`). Adding a new building block (feature space, DE method, control +whole point of the project. See [Create a protocol](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/docs/protocols.md#create-a-protocol) for the +two-step pattern (a pure function in `src/scperteval/protocols/metrics.py` plus a row in +`src/scperteval/protocols/table.py`). Adding a new building block (feature space, DE method, control source, calibrator) the same way is also welcome as a PR. Please include: diff --git a/README.md b/README.md index 625359f..46eda33 100644 --- a/README.md +++ b/README.md @@ -1,504 +1,42 @@ # scPertEval — Evaluation Protocols for Perturbation Sequencing scPertEval is a command-line tool for **experimenting with and sharing reference implementations of -evaluation protocols** in single-cell perturbation studies. +evaluation protocols** in single-cell perturbation studies. The same catalog of protocols backs +three commands: **`score`** (score a model's predictions against ground truth), **`calibrate`** +(calibrate a protocol against empirical positive/negative controls per perturbation, reporting the +**Dynamic Range Fraction (DRF)** and **Bound Discrimination Score (BDS)**), and **`de`** (export +per-gene differential expression). -Evaluating predictions across a dataset's -perturbations reduces to a single question: how different is one group of cells from another? To answer this, an **evaluation protocol** is defined: a specific formulation of a metric, along with some representation of the perturbation data fed to the metric. However, there are a multitude of possibilities -- many already reflected in the literature -- and it can be challenging to compare and contrast protocols across the field and ultimately choose the right approach for a given dataset and problem space. +Our accompanying publication: TODO_LINK_HERE -scPertEval renders each protocol as a short, readable building block to run, read, reuse, and contribute back -- a place for -collaboration and alignment in the field. - -The same catalog of protocols backs three commands, each a different use case: - -- **`score`** — score a model's predictions against ground truth. Each protocol's metric is - applied to your **predicted** cells vs the **real** cells, one score per perturbation — the - conventional "how good is my prediction" evaluation (see - [Scoring predictions](#scoring-predictions-against-ground-truth)). -- **`calibrate`** — calibrate a protocol against empirical positive/negative controls built from - the dataset itself, reporting the **Dynamic Range Fraction (DRF)** and the **Bound Discrimination - Score (BDS)** — quantifying how well the protocol separates real perturbation signal from an - uninformative baseline (see [How calibration works](#how-calibration-works)). Use this to decide - whether a metric is trustworthy in the first place. -- **`de`** — export per-gene differential expression (statistic + adjusted p) to HDF5, since DE - is tightly coupled with several protocols. - -Our accompanying publiciation: TODO_LINK_HERE +**→ Full documentation at ** ## Install ```bash -pip install -e . # provides the `scperteval` command +pip install scperteval ``` -## Input data - -scPertEval reads one preprocessed AnnData (`.h5ad`) per dataset. Only three things are required: - -- **`adata.X`** — normalized expression, cells × genes (e.g. `sc.pp.normalize_total` + `sc.pp.log1p`); sparse or dense float. -- **`adata.obs["perturbation"]`** — the perturbation label for each cell; control cells use the label `"control"`. Both names are configurable (`--perturbation-key` / `--control-label`). -- **`adata.var_names`** — gene identifiers, used as the DEG labels. - -Perturbations with at least `--min-cells` cells (default 30) are evaluated. Nothing else is -needed — references, DE, and PCA are all recomputed in memory, so no `uns`/`obsm`/`layers` are read. - -**Sample datasets.** Seven preprocessed perturbation datasets live in a public, read-only GCS -bucket and serve as a template for the format above: +Or from this repo: ```bash -gsutil ls gs://scperteval/processed/ # wessels23, replogle22{k562,rpe1}, nadig25{hepg2,jurkat}, arch1, kaden25rpe1 -gsutil cp gs://scperteval/processed/wessels23_processed_complete.h5ad . +pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git" ``` -No gcloud account is needed — each file is also reachable over plain HTTPS at -`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. - -## Run it +## Quick start ```bash -# protocols by name — including parameterised ones (set k / padj per protocol) -scperteval calibrate data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test - -# a parameterised protocol with no value uses its default (k=50, padj=0.05) -scperteval calibrate data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU - -# a whole group, or everything (parameterised protocols use their defaults) -scperteval calibrate data/wessels23.h5ad -p distributional --de-method MWU +# calibrate protocols against built-in controls (DRF/BDS) scperteval calibrate data/wessels23.h5ad -p all --de-method t-test -# DRF calibration only (compute DRF only; exclude BDS) -scperteval calibrate data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf - -# SCORE predictions against ground truth — predicted cells vs real cells, per protocol. -# predictions.h5ad must have the same genes and perturbation labels as the dataset. -scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc --de-method t-test - -# DE only — writes per-gene statistic + adjusted p to HDF5 (no protocol calibration) -# Provided as a convenience, since DE methods are tightly coupled with some evaluation protocols -scperteval de data/wessels23.h5ad --methods MWU +# score a model's predictions against ground truth +scperteval score data/wessels23.h5ad predictions.h5ad -p all -# discover what's available -scperteval list protocols # also: de-methods | spaces | sources | calibrators +scperteval list protocols # also: de-methods | spaces | sources | calibrators ``` -Each command prints a summary table and writes a per-perturbation CSV named -`____.csv`: `calibrate` writes the raw control values and the -calibrated DRF/BDS per perturbation (`…__drf.csv` / `…__bds.csv`); `score` writes the raw metric -value per perturbation (`…__score.csv`). `--profile` adds a per-protocol wall-clock timing CSV. - -**DE backends** (`scperteval list de-methods`): `t-test` (default, Welch's, moment-based), -`MWU` (Cliff's δ via illico), and `t-test_overestim_var` (scanpy's conservative-variance -variant — the reference variance is scaled by the target's cell count). Select one with -`--de-method` for a `calibrate`/`score`, or list several with `--methods` for a `de` export. The overestim -variant is a selectable backend for new protocols; no current protocol uses it. - -
scperteval calibrate --help - -``` -usage: scperteval calibrate [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] - [--subsample SUBSAMPLE] [--seed SEED] [--positive POSITIVE] - [--negative NEGATIVE] [--output {drf,bds}] [--out-dir OUT_DIR] - [--workers WORKERS] [--perturbation-key PERTURBATION_KEY] - [--control-label CONTROL_LABEL] [--min-cells MIN_CELLS] - [--profile] [--quiet] - dataset - - -p, --protocols comma-separated names (parameterised as name=value, e.g. - mse_top_k=30), a group (pseudobulk|distributional|de), or 'all' - --de-method {MWU, t-test, t-test_overestim_var} DE backend for every DE unit: - the interpolated positive control, the top_k/degs spaces, - the de_* protocols, and the WMSE weights - --subsample cells in the single-cell reference sample (default 8192) - --output {drf, bds} how per-perturbation values are calibrated - --positive/--negative override a protocol's controls by source name - --min-cells skip perturbations with fewer cells - --profile also write a per-protocol wall-clock timing table -``` -
- -
scperteval score --help - -``` -usage: scperteval score [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] - [--subsample SUBSAMPLE] [--seed SEED] [--out-dir OUT_DIR] [--workers WORKERS] - [--perturbation-key PERTURBATION_KEY] [--control-label CONTROL_LABEL] - [--min-cells MIN_CELLS] [--profile] [--quiet] - dataset predictions - - dataset preprocessed .h5ad — the ground truth (real cells) - predictions predicted .h5ad — same genes and perturbation labels as the dataset - -p, --protocols comma-separated names, a group, or 'all' - --de-method DE backend for the de_* protocols, the top_k/degs spaces, and WMSE weights - --subsample cells in the all-perturbed reference (the ground truth is never subsampled) -``` - -Unlike `calibrate`, there are no `--positive`/`--negative`/`--output` options: the candidate is -always your prediction and the output is always the raw `score`. -
- -## Use it from Python - -Install with `pip install scperteval` (or, from this repo, -`pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git"`). -The simplest path mirrors the CLI — call it via subprocess, exactly as the figure notebook does: - -```python -import subprocess, sys - -subprocess.run([sys.executable, "-m", "scperteval", "calibrate", "data/wessels23.h5ad", - "-p", "all", "--de-method", "t-test", "--out-dir", "results"], check=True) -# -> results/wessels23____drf.csv (raw control values + calibrated DRF per perturbation) - -# score predictions against ground truth instead: -subprocess.run([sys.executable, "-m", "scperteval", "score", "data/wessels23.h5ad", - "predictions.h5ad", "-p", "all", "--out-dir", "results"], check=True) -# -> results/wessels23____score.csv (raw metric value per perturbation) -``` - -## Look up an Evaluation Protocol - -Two files define each protocol: - -- **[`scperteval/protocols/metrics.py`](scperteval/protocols/metrics.py)** — the metric, as a - pure function of the ground truth and a `prediction` (the candidate being scored — a positive - or negative control under `calibrate`, or your model's output under `score`). e.g. `mse`, `mmd`, - `de_auprc`: - ```python - def mse(gt, prediction, ctx): - return float(np.mean((gt - prediction) ** 2)) - ``` -- **[`scperteval/protocols/table.py`](scperteval/protocols/table.py)** — one row wiring that function - to its data: the data representation it receives (`representation`), feature space, - reference centering, positive/negative controls, which direction is `better` - (`"higher"`/`"lower"`), and the `perfect` score: - ```python - Protocol("mse", M.mse, representation="centroid", - positive="interpolated", negative="all_perturbed_mean", better="lower", perfect=0.0) - ``` - -The next section breaks these arguments down while building one up from scratch. - -## Create a protocol - -A protocol is two things: a pure metric **function** and a one-line **spec** that wires it -to data and scoring. We'll ease in — the simplest possible protocol first, then the spec -broken down, then a few richer examples. - -### Start simple - -Here is a complete new protocol: mean absolute error on the standard pseudobulk profiles. - -1. Add a pure function to [`scperteval/protocols/metrics.py`](scperteval/protocols/metrics.py): - ```python - def mae(gt, prediction, ctx): - return float(np.mean(np.abs(gt - prediction))) - ``` - Every metric function has this signature. `gt` is one perturbation's ground-truth - profile; `prediction` is the candidate being compared against it (under `calibrate`, scPertEval - calls the function once for the positive control and once for the negative; under `score`, once - with your model's prediction). `ctx` is the dataset context, needed by only a few metrics — - ignore it otherwise. Return a single number. - -2. Add a row to [`scperteval/protocols/table.py`](scperteval/protocols/table.py): - ```python - Protocol("mae", M.mae, representation="centroid", - positive="interpolated", negative="all_perturbed_mean", - better="lower", perfect=0.0) - ``` - -Run it with `scperteval calibrate data.h5ad -p mae`. That is the whole protocol: MAE between each -perturbation's pseudobulk profile and its positive and negative controls, scored as -lower-is-better toward a perfect of 0. - -### The spec - -That row is the spec; parameters include: - -| argument | meaning | -|---|---| -| `name` | selects the protocol on the CLI (`-p mae`) | -| `representation` | the shape of each datapoint your function receives (see below) | -| `scope` | `"perturbation"` (default) or `"dataset"` — how many perturbations at once (see below) | -| `space` | which features to score — `full` (default), or a feature space like `top_50` | -| `centering` | a baseline subtracted before scoring, e.g. `"ctrl"` (default: none) | -| `positive` / `negative` | the two control sources to compare | -| `better` | `"higher"` or `"lower"` — which direction is an improvement | -| `perfect` | the value a flawless prediction attains | -| `param` | optional — a parameter family (`top_k`, `pca_k`, `degs_padj`, `overlap_k`) that makes the protocol tunable from the CLI; omit for a fixed protocol | - -**`representation`** decides the *shape* of each datapoint — the format `gt` and -`prediction` arrive in — so you never deal with sampling, references, or projection yourself: - -| `representation` | a datapoint is | -|---|---| -| `centroid` | a 1-D pseudobulk vector (one value per gene) | -| `population` | a `(cells × genes)` matrix | -| `de` | a `DEResult` (for the ground truth) / per-gene `|score|` ranking (for a prediction) | - -**`scope`** is the independent companion axis — *how many* perturbations the metric sees at once: - -| `scope` | the metric is called | -|---|---| -| `perturbation` (default) | once per perturbation — gets that perturbation's `(gt, prediction)` datapoints and returns a scalar | -| `dataset` | once for the whole dataset — gets the **list** of every perturbation's `gt` and `prediction` datapoints and returns one score per perturbation (e.g. a retrieval `rank`) | - -The two compose freely: `rank` is just `representation="centroid", scope="dataset"`; a -distributional retrieval metric would be `representation="population", scope="dataset"`. - -Many rows repeat the same wiring, so the top of `table.py` predefines the common -combinations as plain dicts. You then unpack one into a row with `**` (Python's -keyword-expansion syntax) to avoid retyping it: -```python -_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") -_LOWER = dict(better="lower", perfect=0.0) -``` -With those, the `mae` row above is exactly `Protocol("mae", M.mae, -representation="centroid", **_PB, **_LOWER)` — same protocol, less repetition. You'll see -these bundles reused throughout the table. - -### Building blocks — the palette - -The values those arguments take — feature spaces, control sources, DE methods, calibrators -— are registered building blocks. `scperteval list ` shows what's available -in each, with descriptions: - -**Feature spaces** (the `space` argument) - -```bash -$ scperteval list spaces -degs_0.05 — ground-truth DEGs at adjusted p < 0.05, per perturbation -full — all genes, no transform -pca_50 — top 50 principal components (fit on the dataset) -top_50 — top 50 genes by ground-truth effect size, per perturbation -``` - -`top_` / `pca_` / `degs_` are parameterised families (the defaults are shown); -a protocol template picks the value. If the space you need isn't here, see -[Add a feature space](#add-a-feature-space). - -**DE methods** (the `--de-method` choice) - -```bash -$ scperteval list de-methods -MWU — Mann-Whitney U / Cliff's delta effect size (via illico) -t-test — Welch's t-test (default) — moment-based and fast -``` - -Chosen with `--de-method`; it applies to **every** DE-dependent unit (the `interpolated` -positive control, the `top_k`/`degs` spaces, the `de_*` protocols, and the WMSE weights). -To add another, see [Add a DE method](#add-a-de-method). - -**Control sources** (the `positive` / `negative` arguments) - -```bash -$ scperteval list sources -all_perturbed (cells) — all-perturbed reference sample, leave-one-out (single-cell negative control) -all_perturbed_mean (centroid) — all-perturbed mean, excluding the target — leave-one-out (pseudobulk sibling of all_perturbed; pseudobulk negative control) -control (cells) — non-targeting control cells -global_mean (centroid) — mean of all perturbations — shared baseline for the ranking protocols -gt_all_cells (cells) — ground truth — all of a perturbation's real cells (prediction-scoring truth) -gt_half (cells) — ground truth — the first half of a perturbation's cells (calibration truth) -interpolated (centroid) — interpolated duplicate — DE-weighted blend of the held-out half and the dataset mean (pseudobulk positive control) -prediction (cells) — model-predicted cells for the perturbation, from the --predictions h5ad -tech_dup (cells) — technical duplicate — the held-out second half (single-cell positive control) -``` - -The truth source is chosen by the command, not by a protocol: `calibrate` uses `gt_half` and -holds the other half out to build the positive control; `score` uses `gt_all_cells` and compares -it to `prediction`. - -Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative=` (or -`--positive`/`--negative`). To add another, see [Add a control source](#add-a-control-source). - -**Calibrators** (the `--output` choice) - -```bash -$ scperteval list calibrators -drf — Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025) -bds — Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026) -score — raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode) -``` - -`drf`/`bds` are chosen with `calibrate --output`; `score` is selected automatically by the -`score` command. To add another, see [Add a calibrator](#add-a-calibrator). - -### More examples - -With the spec and the palette in hand, richer protocols are just different combinations. - -**Same wiring, different metric.** Cosine distance on pseudobulk reuses the bundles wholesale: -```python -def cosine(gt, prediction, ctx): - return 1.0 - float(gt @ prediction / (np.linalg.norm(gt) * np.linalg.norm(prediction))) -``` -```python -Protocol("cosine", M.cosine, representation="centroid", **_PB, **_LOWER) -``` - -**Restrict to a feature space.** Set `space` to score only some genes — e.g. MAE on the -top-50 DEGs: -```python -Protocol("mae_top50", M.mae, representation="centroid", space="top_50", **_PB, **_LOWER) -``` - -**Expose the space as a knob (parameterised).** To make `k` adjustable per invocation, add a -`param` to the same `Protocol(...)` row — nothing else changes. The row's name carries the -parameter, and the value is supplied on the CLI: -```python -Protocol("mae_top_k", M.mae, representation="centroid", param=top_k, **_PB, **_LOWER) -``` -Then `scperteval calibrate data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The -families are `top_k` (top-k DEGs), `pca_k` (k PCs), and `degs_padj` (DEGs at adjusted -p < padj) for the space, and `overlap_k` to feed an integer straight to the metric. - -**A metric over cells, not profiles.** Switch `representation` to `population` and your -function receives `(cells × genes)` matrices; pair it with the single-cell controls: -```python -def my_mmd(gt, prediction, ctx): # gt, prediction are (cells × genes) - ... -``` -```python -Protocol("my_mmd_top50", M.my_mmd, representation="population", space="top_50", - positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) -``` -This changes two pieces at once — the `representation` (so the function sees cells) and the controls -(the single-cell positive/negative) — which is the general pattern for a distributional -protocol. - -By now you've seen every moving part: the function, the spec, the building blocks the spec -draws on, fixed and parameterised spaces, and switching the representation the function -sees. Most new metrics are some combination of these. - -## Add a building block - -Spaces, DE methods, control sources, and calibrators are registered units — add one when -the palette is missing what a new protocol needs. Each is a small function (or object) plus -a one-line registration. - -### Add a feature space - -A space is a function `(X, ctx, pert) -> dense (cells × genes) array` that transforms the -gene axis. Register it with `@SPACES.register` in -[`scperteval/blocks/spaces.py`](scperteval/blocks/spaces.py); pass `global_space=True` if it doesn't -depend on the perturbation (so it can be computed once and shared): - -```python -@SPACES.register("hvg_100", global_space=True, description="100 highest-variance genes") -def space_hvg(X, ctx, pert): - keep = ... # indices of the genes to keep - return to_dense(X[:, keep]) -``` - -For a per-perturbation subset derived from the ground-truth DE (like `top_k` / `degs`), use -the `register_de_space(name, field=..., top=...)` helper in the same file instead. - -### Add a DE method - -A DE method maps `(target_cells, reference_cells) -> DEResult(score, pvalue, pvalue_adj)`. -Register it with `@DE_METHODS.register` in [`scperteval/blocks/de.py`](scperteval/blocks/de.py) (the -`bh` helper there BH-adjusts p-values): - -```python -@DE_METHODS.register("my_test", description="…") -def de_my_test(target, reference): - score, pvalue = ... # per-gene statistic and raw p-value - return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) -``` - -Then `--de-method my_test` routes every DE-dependent unit through it. - -### Add a control source - -A source maps `(ctx, pert) -> cells or a 1-D centroid`, declaring which with `provides`. -Register it with `@SOURCES.register` in [`scperteval/sources.py`](scperteval/sources.py): - -```python -@SOURCES.register("my_baseline", provides="centroid", description="…") -def src_my_baseline(ctx, pert): - return ... # a 1-D centroid (or cells, if provides="cells") -``` - -Use it as a control via `positive=`/`negative=` in a row, or `--positive`/`--negative` at -the CLI. - -### Add a calibrator - -A calibrator declares the control roles it needs, a per-perturbation combine, and a -cross-perturbation aggregate. Add a `Calibrator` to the `CALIBRATORS` dict in -[`scperteval/calibrators.py`](scperteval/calibrators.py): - -```python -CALIBRATORS["my_score"] = Calibrator( - "my_score", ("positive", "negative"), - per_pert=lambda raws, p: ..., # raws["positive"], raws["negative"] -> one number - aggregate=lambda v: {"my_score": float(np.nanmean(v))}, - description="…", -) -``` - -Then `--output my_score` reports it. - -## Scoring predictions against ground truth - -`scperteval score dataset.h5ad predictions.h5ad` is the conventional evaluation: each protocol's -metric is applied to your **predicted** cells against the **real** cells, one score per -perturbation. It runs the *same* protocol catalog as `calibrate`; only two pieces differ. - -- **ground truth** — *all* of a perturbation's real cells (the `gt_all_cells` source). Unlike - calibration, no half is held out and no positive/negative controls are built — the ground - truth is the whole real population. -- **prediction** — the matching cells from your `predictions.h5ad` (the `prediction` source). - The prediction file must contain the dataset's exact gene set (any order — columns are - reordered by name so the comparison lines up gene-for-gene) and the same perturbation labels. - A gene-set mismatch, or a perturbation present in the dataset but absent from the predictions, - raises an error naming exactly what's wrong. - -The `score` calibrator reports each protocol's raw metric value per perturbation and its -mean/median across perturbations, written to `____score.csv`. Higher- vs -lower-is-better follows each protocol's `better` field, exactly as in calibration. - -Architecturally this reuses everything — the per-perturbation loop, every metric, representation, -and feature space are shared with `calibrate`. The only differences are the **truth source** -(`gt_all_cells` instead of the held-out `gt_half`) and the **calibrator** (`score`, which needs -only the prediction, instead of `drf`/`bds`, which need both controls). The DE-derived feature -spaces (`top_k`, `degs`) and the WMSE weights are computed from this same all-cells ground truth. - -## How calibration works - -scPertEval's claim — a usable catalog of protocols — rests on **calibrating** each protocol -against two empirical controls per perturbation, so you can see whether a metric actually -separates signal from baseline rather than read a raw, uninterpretable number. - -- **positive control** — the best realistic candidate: the **technical duplicate** (a - held-out replicate) for single-cell protocols, the **interpolated duplicate** for pseudobulk. -- **negative control** — an uninformative baseline: the **all-perturbed reference, - excluding the target perturbation** (a full-resolution mean for pseudobulk; an 8192-cell - subsample for single-cell distances). - -**Dynamic Range Fraction (DRF)** — where the protocol's value sits between the negative -control (floor) and the perfect score, anchored by the positive control: - -``` -DRF = (positive − negative) / (perfect − negative) # per perturbation, clipped to [-1, 1] -``` - -`--output drf` reports the mean/median across perturbations. High DRF means the protocol -discriminates real signal; near zero means it doesn't. Introduced by Miller et al., -*Deep Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines on -Well-Calibrated Metrics* (2025) — . - -**Bound Discrimination Score (BDS)** — the fraction of perturbations for which the positive -control beats the negative control under this protocol: - -``` -BDS = fraction of perturbations where positive control beats negative control # in [0, 1] -``` - -`--output bds` reports this fraction. It's a sensitivity check: a protocol with low BDS -can't even tell a technical replicate from an uninformative baseline, so its scores -shouldn't be trusted. Introduced by Vollenweider & Bühlmann, *Signal, Bounds, and -Baselines* (SBB, 2026) — (code: -). +Sample datasets are available at +`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. --- diff --git a/docs/_static/.gitkeep b/docs/_static/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css new file mode 100644 index 0000000..b8c8d47 --- /dev/null +++ b/docs/_static/css/custom.css @@ -0,0 +1,4 @@ +/* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ +div.cell_output table.dataframe { + font-size: 0.8em; +} diff --git a/docs/_templates/.gitkeep b/docs/_templates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 0000000..834712e --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,57 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + +{% set public_methods = methods | reject("equalto", "__init__") | list %} + +{% block attributes %} +{% if attributes %} +Attributes table +~~~~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in attributes %} + ~{{ name }}.{{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block methods %} +{% if public_methods %} +Methods table +~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in public_methods %} + ~{{ name }}.{{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block attributes_documentation %} +{% if attributes %} +Attributes +~~~~~~~~~~ + +{% for item in attributes %} + +.. autoattribute:: {{ [objname, item] | join(".") }} +{%- endfor %} + +{% endif %} +{% endblock %} + +{% block methods_documentation %} +{% if public_methods %} +Methods +~~~~~~~ + +{% for item in public_methods %} + +.. automethod:: {{ [objname, item] | join(".") }} +{%- endfor %} + +{% endif %} +{% endblock %} diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000..5c40462 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,193 @@ +# API + +## Core types + +```{eval-rst} +.. module:: scperteval.types +.. currentmodule:: scperteval.types + +.. autosummary:: + :toctree: generated + + RunConfig + Protocol + Calibrator + DEResult + Param +``` + +## Runner + +```{eval-rst} +.. module:: scperteval.runner +.. currentmodule:: scperteval.runner + +.. autosummary:: + :toctree: generated + + run_protocol +``` + +## Protocols + +- `scperteval.protocols.TABLE` — list of all `Protocol` objects. +- `scperteval.protocols.PROTOCOLS` — `{name: Protocol}` dict. +- `scperteval.protocols.GROUPS` — sorted list of group names. + +```{eval-rst} +.. protocol-table:: +``` + +### Metrics + +```{eval-rst} +.. module:: scperteval.protocols.metrics +.. currentmodule:: scperteval.protocols.metrics + +.. automodule:: scperteval.protocols.metrics + :no-members: + :no-index: + +.. autosummary:: + :toctree: generated + + pearson + mse + weighted_mse + energy_distance + unbiased_mmd_median + sinkhorn_w2 + rank_retrieval + de_auprc + de_auroc + de_overlap +``` + +## Calibrators + +```{eval-rst} +.. automodule:: scperteval.calibrators + :no-members: +``` + +`scperteval.calibrators.CALIBRATORS` — `{name: Calibrator}` dict of built-in calibrators (`drf`, `bds`, and `score` for the prediction-scoring mode). +Add entries here to register a new calibrator; see [Add a calibrator](user-guide/building-blocks.md#add-a-calibrator). + +## Building blocks + +### Differential expression + +```{eval-rst} +.. module:: scperteval.blocks.de +.. currentmodule:: scperteval.blocks.de + +.. automodule:: scperteval.blocks.de + :no-members: + :no-index: + +.. autosummary:: + :toctree: generated + + DE_METHODS + moments + bh + ttest_from_moments + de_ttest + de_ttest_overestim + de_mwu +``` + +### Feature spaces + +```{eval-rst} +.. module:: scperteval.blocks.spaces +.. currentmodule:: scperteval.blocks.spaces + +.. automodule:: scperteval.blocks.spaces + :no-members: + :no-index: + +.. autosummary:: + :toctree: generated + + SPACES + register_de_space + top_space + pca_space + degs_space +``` + +### Control sources + +```{eval-rst} +.. automodule:: scperteval.sources + :no-members: +``` + +`scperteval.sources.SOURCES` — registry of all control/reference sources. +Add entries here to register a new source; see [Add a control source](user-guide/building-blocks.md#add-a-control-source). + +### Predictions + +```{eval-rst} +.. module:: scperteval.predictions +.. currentmodule:: scperteval.predictions + +.. autosummary:: + :toctree: generated + + PredictionSet +``` + +`scperteval.predictions.PredictionSet` — model-predicted cells loaded from a `.h5ad` and +gene-aligned to the dataset, used by the `score` command. + +## Context + +```{eval-rst} +.. module:: scperteval.context +.. currentmodule:: scperteval.context + +.. autosummary:: + :toctree: generated + + Context +``` + +## Registry + +```{eval-rst} +.. module:: scperteval.registry +.. currentmodule:: scperteval.registry + +.. autosummary:: + :toctree: generated + + Registry +``` + +## Dataset & I/O + +```{eval-rst} +.. module:: scperteval.dataset +.. currentmodule:: scperteval.dataset + +.. autosummary:: + :toctree: generated + + Dataset + to_dense +``` + +```{eval-rst} +.. module:: scperteval.io +.. currentmodule:: scperteval.io + +.. autosummary:: + :toctree: generated + + print_summary + write_rows + write_timing + write_de +``` diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 0000000..d9e79ba --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,3 @@ +```{include} ../CHANGELOG.md + +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..23689c7 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,126 @@ +# Configuration file for the Sphinx documentation builder. +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import shutil +import sys +from datetime import datetime +from importlib.metadata import metadata +from pathlib import Path + +from sphinxcontrib import katex + +HERE = Path(__file__).parent +sys.path.insert(0, str(HERE / "extensions")) + +# -- Project information ----------------------------------------------------- + +info = metadata("scperteval") +project = info["Name"] +author = info.get("Author") or "scPertEval authors" +copyright = f"{datetime.now():%Y}, {author}" +version = info["Version"] +_project_urls = info.get_all("Project-URL") or [] +urls = dict(pu.split(", ", 1) for pu in _project_urls) +repository_url = urls.get("Source", "https://github.com/Virtual-Cell-Research-Community/scPertEval") + +release = info["Version"] + +bibtex_bibfiles = ["references.bib"] +bibtex_reference_style = "author_year" +templates_path = ["_templates"] +nitpicky = True +needs_sphinx = "4.0" + +html_context = { + "display_github": True, + "github_user": "Virtual-Cell-Research-Community", + "github_repo": "scPertEval", + "github_version": "main", + "conf_py_path": "/docs/", +} + +# -- General configuration --------------------------------------------------- + +extensions = [ + "myst_nb", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinxcontrib.bibtex", + "sphinxcontrib.katex", + "sphinx_autodoc_typehints", + "sphinx_design", + "IPython.sphinxext.ipython_console_highlighting", + "sphinxext.opengraph", + *[p.stem for p in (HERE / "extensions").glob("*.py")], +] + +autosummary_generate = True +autodoc_member_order = "groupwise" +default_role = "literal" +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_use_rtype = True +napoleon_use_param = True +myst_heading_anchors = 6 +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", + "html_admonition", +] +myst_url_schemes = ("http", "https", "mailto") +nb_output_stderr = "remove" +nb_execution_mode = "off" +nb_merge_streams = True +typehints_defaults = "braces" +always_use_bars_union = True + +source_suffix = { + ".rst": "restructuredtext", + ".ipynb": "myst-nb", + ".myst": "myst-nb", +} + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "sklearn": ("https://scikit-learn.org/stable/", None), +} + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints", "notebooks/README.md"] + +# -- Options for HTML output ------------------------------------------------- + +html_theme = "sphinx_book_theme" +html_static_path = ["_static"] +html_css_files = ["css/custom.css"] + +html_title = project + +html_theme_options = { + "repository_url": repository_url, + "use_repository_button": True, + "path_to_docs": "docs/", + "navigation_with_keys": False, + "show_navbar_depth": 1, +} + +pygments_style = "default" +katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None + +nitpick_ignore = [ # type: ignore + # Add exceptions here for links outside your control that fail to resolve + ("py:class", "Context"), + ("py:class", "Dataset"), + # Internal classes referenced in type hints but not given their own API page. + ("py:class", "scperteval.reference.Reference"), +] diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..69b9914 --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,6 @@ +```{include} ../CONTRIBUTORS.md + +``` + +For development setup (installing dependencies, running linters, building docs locally), +see [Installation](installation.md#development-setup). diff --git a/docs/extensions/protocol_table.py b/docs/extensions/protocol_table.py new file mode 100644 index 0000000..360ffc0 --- /dev/null +++ b/docs/extensions/protocol_table.py @@ -0,0 +1,60 @@ +"""Sphinx directive that auto-generates a reference table of evaluation protocols.""" +from __future__ import annotations + +from docutils import nodes +from docutils.parsers.rst import Directive +from sphinx.application import Sphinx + + +class ProtocolTableDirective(Directive): + """Emit a table of all protocols from ``scperteval.protocols.TABLE``.""" + + def run(self): + """Build and return the protocol reference table node.""" + from scperteval.protocols import TABLE + + table = nodes.table() + tgroup = nodes.tgroup(cols=4) + table += tgroup + + for _ in range(4): + tgroup += nodes.colspec(colwidth=1) + + thead = nodes.thead() + tgroup += thead + header_row = nodes.row() + for text in ("Name", "Group", "Representation", "Better"): + entry = nodes.entry() + entry += nodes.paragraph(text=text) + header_row += entry + thead += header_row + + tbody = nodes.tbody() + tgroup += tbody + for p in TABLE: + row = nodes.row() + + # Name cell — link to the metric function page if available + name_entry = nodes.entry() + metric_name = p.metric.__name__ if hasattr(p.metric, "__name__") else p.name + ref_id = f"scperteval.protocols.metrics.{metric_name}" + ref = nodes.reference("", p.name, internal=True, refuri=f"generated/{ref_id}.html") + name_para = nodes.paragraph() + name_para += ref + name_entry += name_para + row += name_entry + + for text in (p.group, p.representation, p.better): + entry = nodes.entry() + entry += nodes.paragraph(text=text) + row += entry + + tbody += row + + return [table] + + +def setup(app: Sphinx): + """Register the ``protocol-table`` directive with Sphinx.""" + app.add_directive("protocol-table", ProtocolTableDirective) + return {"version": "0.1", "parallel_read_safe": True} diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py new file mode 100644 index 0000000..5342ccf --- /dev/null +++ b/docs/extensions/typed_returns.py @@ -0,0 +1,31 @@ +# code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py +from __future__ import annotations + +import re +from collections.abc import Generator, Iterable + +from sphinx.application import Sphinx +from sphinx.ext.napoleon import NumpyDocstring # type: ignore + + +def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: + for line in lines: + if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): + yield f"-{m['param']} (:class:`~{m['type']}`)" + else: + yield line + + +def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: + lines_raw = self._dedent(self._consume_to_next_section()) + if lines_raw[0] == ":": + del lines_raw[0] + lines = self._format_block(":returns: ", list(_process_return(lines_raw))) + if lines and lines[-1]: + lines.append("") + return lines + + +def setup(app: Sphinx): + """Set app.""" + NumpyDocstring._parse_returns_section = _parse_returns_section # type: ignore diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..4a2ea58 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,93 @@ +# scPertEval — Evaluation Protocols for Perturbation Sequencing + +scPertEval is a command-line tool for **experimenting with and sharing reference implementations of +evaluation protocols** in single-cell perturbation studies. + +Evaluating predictions across a dataset's perturbations reduces to a single question: how +different is one group of cells from another? To answer this, an **evaluation protocol** is +defined: a specific formulation of a metric, along with some representation of the +perturbation data fed to the metric. However, there are a multitude of possibilities — many +already reflected in the literature — and it can be challenging to compare and contrast +protocols across the field and ultimately choose the right approach for a given dataset and +problem space. + +scPertEval renders each protocol as a short, readable building block to run, read, reuse, +and contribute back — a place for collaboration and alignment in the field. The same catalog +of protocols backs three commands: + +- **`score`** — score a model's predictions against ground truth, one metric value per + perturbation (see [Scoring predictions](user-guide/scoring.md)). +- **`calibrate`** — calibrate a protocol against built-in positive/negative controls, reporting + the **Dynamic Range Fraction (DRF)** and **Bound Discrimination Score (BDS)** — how well it + separates real signal from an uninformative baseline (see [Calibration](user-guide/calibration.md)). +- **`de`** — export per-gene differential expression to HDF5. + +## Quick start + +```bash +pip install scperteval +scperteval calibrate data/wessels23.h5ad -p all --de-method t-test +``` + +::::{grid} 1 2 3 3 +:gutter: 2 + +:::{grid-item-card} {octicon}`desktop-download;1em;` Installation +:link: installation +:link-type: doc +Get scPertEval installed and set up your development environment. +::: + +:::{grid-item-card} {octicon}`book;1em;` User guide +:link: user-guide/index +:link-type: doc +Learn how to run protocols, interpret scores, and explore the building blocks. +::: + +:::{grid-item-card} {octicon}`mortar-board;1em;` Tutorials +:link: tutorials +:link-type: doc +Step-by-step notebooks: CLI walkthrough, Python API, and extending the tool. +::: + +:::{grid-item-card} {octicon}`code-square;1em;` API reference +:link: api +:link-type: doc +Full reference for the Python API. +::: + +:::{grid-item-card} {octicon}`mark-github;1em;` GitHub +:link: https://github.com/Virtual-Cell-Research-Community/scPertEval +:link-type: url +Browse the source code, open issues, or contribute a pull request. +::: + +:::: + +## Citation + +If you use scPertEval, please cite {cite}`Schafer_2026`. + +```bibtex +@unpublished{Schafer_2026, + author = {Schäfer, Philipp S. L. and Reid, Kendall A. and Boldyga, Zach + and Aksu, Ekin Deniz and Hakem, Hugo and Saez-Rodriguez, Julio}, + title = {Towards a Principled Evaluation of Single-Cell Perturbation + Response Prediction Models}, + note = {In preparation}, + year = {2026}, +} +``` + +```{toctree} +:hidden: true +:maxdepth: 2 + +installation.md +user-guide/index +tutorials.md +api.md +changelog.md +Contributing +references.md +``` diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..6e149af --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,44 @@ +# Installation + +## From PyPI + +```bash +pip install scperteval +``` + +## From source + +```bash +pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git" +``` + +Or, for an editable install from a local clone: + +```bash +git clone https://github.com/Virtual-Cell-Research-Community/scPertEval.git +cd scPertEval +pip install -e . +``` + +## Development setup + +Install all dev dependencies (linting + docs + tests): + +```bash +uv sync --group dev +``` + +Run linters: + +```bash +uv run ruff format . +uv run ruff check . +uv run mypy src/scperteval +``` + +Build the docs locally with live reload: + +```bash +uv sync --group docs +uv run sphinx-autobuild docs docs/_build/html +``` diff --git a/docs/notebooks/.gitkeep b/docs/notebooks/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/references.bib b/docs/references.bib new file mode 100644 index 0000000..fbcd467 --- /dev/null +++ b/docs/references.bib @@ -0,0 +1,26 @@ +@unpublished{Schafer_2026, + author = {Schäfer, Philipp S. L. and Reid, Kendall A. and Boldyga, Zach and Aksu, Ekin Deniz and Hakem, Hugo and Saez-Rodriguez, Julio}, + title = {Towards a Principled Evaluation of Single-Cell Perturbation Response Prediction Models}, + note = {In preparation}, + year = {2026}, +} + +@misc{Miller_2025, + title = {Deep {{Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines}} on {{Well-Calibrated Metrics}}}, + author = {Miller, Henry E. and Mejia, Gabriel M. and Leblanc, Francis J. A. and Wang, Bo and Swain, Brendan and Camillo, Lucas Paulo de Lima}, + year = {2025}, + month = oct, + publisher = {bioRxiv}, + pages = {2025.10.20.683304}, + doi = {10.1101/2025.10.20.683304}, +} + +@misc{Vollenweider_2026, + title = {Signal, {{Bounds}}, and {{Baselines}}: {{Principles}} for {{Rigorous Evaluation}} of {{High-Dimensional Biological Perturbation Prediction}}}, + author = {Vollenweider, Michael and B{\"u}hlmann, Peter}, + year = {2026}, + month = apr, + publisher = {bioRxiv}, + pages = {2026.04.20.719650}, + doi = {10.64898/2026.04.20.719650}, +} diff --git a/docs/references.md b/docs/references.md new file mode 100644 index 0000000..00ad6a6 --- /dev/null +++ b/docs/references.md @@ -0,0 +1,5 @@ +# References + +```{bibliography} +:cited: +``` diff --git a/docs/tutorials.md b/docs/tutorials.md new file mode 100644 index 0000000..9cd4fb2 --- /dev/null +++ b/docs/tutorials.md @@ -0,0 +1,7 @@ +# Tutorials + +Notebooks are coming soon. Planned tutorials: + +- **CLI walkthrough** — run protocols on a dataset end-to-end from the command line +- **Python API** — use scPertEval programmatically from a notebook or script +- **Extending scPertEval** — add a new protocol, feature space, or control source diff --git a/docs/user-guide/building-blocks.md b/docs/user-guide/building-blocks.md new file mode 100644 index 0000000..ffd2ede --- /dev/null +++ b/docs/user-guide/building-blocks.md @@ -0,0 +1,68 @@ +# Building blocks + +Spaces, DE methods, control sources, and calibrators are registered units — add one when +the palette is missing what a new protocol needs. Each is a small function (or object) plus +a one-line registration. + +## Add a feature space + +A space is a function `(X, ctx, pert) -> dense (cells × genes) array` that transforms the +gene axis. Register it with `@SPACES.register` in +[`src/scperteval/blocks/spaces.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/blocks/spaces.py); pass `global_space=True` if it doesn't +depend on the perturbation (so it can be computed once and shared): + +```python +@SPACES.register("hvg_100", global_space=True, description="100 highest-variance genes") +def space_hvg(X, ctx, pert): + keep = ... # indices of the genes to keep + return to_dense(X[:, keep]) +``` + +For a per-perturbation subset derived from the ground-truth DE (like `top_k` / `degs`), use +the `register_de_space(name, field=..., top=...)` helper in the same file instead. + +## Add a DE method + +A DE method maps `(target_cells, reference_cells) -> DEResult(score, pvalue, pvalue_adj)`. +Register it with `@DE_METHODS.register` in [`src/scperteval/blocks/de.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/blocks/de.py) (the +`bh` helper there BH-adjusts p-values): + +```python +@DE_METHODS.register("my_test", description="…") +def de_my_test(target, reference): + score, pvalue = ... # per-gene statistic and raw p-value + return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) +``` + +Then `--de-method my_test` routes every DE-dependent unit through it. + +## Add a control source + +A source maps `(ctx, pert) -> cells or a 1-D centroid`, declaring which with `provides`. +Register it with `@SOURCES.register` in [`src/scperteval/sources.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/sources.py): + +```python +@SOURCES.register("my_baseline", provides="centroid", description="…") +def src_my_baseline(ctx, pert): + return ... # a 1-D centroid (or cells, if provides="cells") +``` + +Use it as a control via `positive=`/`negative=` in a row, or `--positive`/`--negative` at +the CLI. + +## Add a calibrator + +A calibrator declares the control roles it needs, a per-perturbation combine, and a +cross-perturbation aggregate. Add a `Calibrator` to the `CALIBRATORS` dict in +[`src/scperteval/calibrators.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/calibrators.py): + +```python +CALIBRATORS["my_score"] = Calibrator( + "my_score", ("positive", "negative"), + per_pert=lambda raws, p: ..., # raws["positive"], raws["negative"] -> one number + aggregate=lambda v: {"my_score": float(np.nanmean(v))}, + description="…", +) +``` + +Then `--output my_score` reports it. diff --git a/docs/user-guide/calibration.md b/docs/user-guide/calibration.md new file mode 100644 index 0000000..58fe04e --- /dev/null +++ b/docs/user-guide/calibration.md @@ -0,0 +1,89 @@ +# Calibration + +scPertEval assesses each **evaluation protocol** — a representation $\phi$ paired with a +metric $d$ — for its **separability**: can it reliably distinguish a perturbation's true +response from an uninformative baseline? + +Let $s(\mathcal{X}, \mathcal{Y}) = d(\phi(\mathcal{X}), \phi(\mathcal{Y}))$ denote the +protocol-induced score, where **smaller values indicate better agreement** (similarity scores +are converted beforehand). + +## Controls + +For each perturbation $a$, every protocol is evaluated against two empirical controls: + +- **Positive control** $s_{\text{pos}}^{(a)}$ — the best realistic score: comparing the + observed cells against a **technical duplicate** (a held-out replicate); for pseudobulk + protocols, an **interpolated duplicate** is used for stability. +- **Negative control** $s_{\text{neg}}^{(a)}$ — an uninformative baseline: comparing against + the **all-perturbed reference excluding $a$** (full mean for pseudobulk; a subsample of + 8 192 cells by default for single-cell distances, configurable with `--subsample`). +Ideally $s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}$. + +## Dynamic Range Fraction (DRF) + +DRF asks: **how much of the available signal range does the protocol actually recover?** + +$$ +\operatorname{DRF}(a) += \frac{s_{\text{neg}}^{(a)} - s_{\text{pos}}^{(a)}}{s_{\text{neg}}^{(a)} - s_{\text{optim}} + \xi} +$$ + +where $s_{\text{optim}}$ is the protocol's ideal score (0 for distance metrics; set by the +`perfect` field in the protocol spec) and $\xi > 0$ is a small stabilising constant. +The numerator is the recovered gap (how much the positive control beats the negative); +the denominator is the total available dynamic range from baseline down to optimal. + +| $\operatorname{DRF}(a)$ | Meaning | +|---|---| +| $= 1$ | positive control achieves the optimal score | +| $= 0$ | positive and negative controls score equally | +| $< 0$ | positive control performs *worse* than the uninformative baseline | + +`--output drf` reports the mean/median of $\operatorname{DRF}(a)$ across perturbations. +Introduced by {cite}`Miller_2025`. + +## Bound Discrimination Score (BDS) + +BDS asks a simpler, binary question: **for what fraction of perturbations does the protocol +get the ordering right?** + +$$ +\operatorname{BDS} += \frac{1}{|\mathcal{P}|} + \sum_{a \in \mathcal{P}} + \mathbf{1}\!\left[s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}\right] +$$ + +It records whether the positive control beats the negative, but not by how much. +A protocol with low BDS cannot distinguish a technical replicate from a random reference; +its scores should not be trusted regardless of their magnitude. + +`--output bds` reports this fraction. Introduced by {cite}`Vollenweider_2026`. + +## DRF vs BDS + +The two scores are complementary. BDS checks the **sign** — does +$s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}$? DRF checks the **magnitude** — how far along +the full dynamic range is that gap? A protocol can have high BDS (ordering consistently +correct) yet low DRF (margin negligible relative to what is achievable). Use both together: +BDS as a pass/fail gate on directionality, DRF as a quantitative measure of signal recovery. + +## In practice + +```bash +scperteval calibrate data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test +``` + + +| protocol | DRF (mean) | BDS | +|---|---|---| +| `pearson_ctrl` | … | … | +| `unbiased_mmd_median_pca_k=20` | … | … | +| `de_overlap_k=10` | … | … | + +A protocol with BDS < 0.5 cannot reliably order its controls — its scores should not be trusted +regardless of magnitude. A protocol with high BDS but low DRF is directionally correct but +recovers little of the available signal range. + +→ See [Usage](usage.md) for the full CLI reference, all options, and `--help` output. diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md new file mode 100644 index 0000000..261ca2d --- /dev/null +++ b/docs/user-guide/index.md @@ -0,0 +1,18 @@ +# User guide + +scPertEval runs in two modes: + +- [Scoring predictions](scoring) compares a model's output to ground truth +- [Calibration](calibration) measures whether a protocol can tell real signal from an uninformative baseline + +Start with whichever matches your goal, then see [Usage](usage). + +```{toctree} +:maxdepth: 1 + +scoring +calibration +usage +protocols +building-blocks +``` diff --git a/docs/user-guide/protocols.md b/docs/user-guide/protocols.md new file mode 100644 index 0000000..a8a0b94 --- /dev/null +++ b/docs/user-guide/protocols.md @@ -0,0 +1,230 @@ +# Protocols + +## Look up an evaluation protocol + +Two files define each protocol: + +- **[`src/scperteval/protocols/metrics.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/metrics.py)** — the metric, as a + pure function of the ground truth and a `prediction` (the candidate being scored — a positive + or negative control under `calibrate`, or your model's output under `score`). e.g. `mse`, + `mmd`, `de_auprc`: + + ```python + def mse(gt, prediction, ctx): + return float(np.mean((gt - prediction) ** 2)) + ``` + +- **[`src/scperteval/protocols/table.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/table.py)** — one row wiring that function + to its data: the data representation it receives (`representation`), feature space, + reference centering, positive/negative controls, which direction is `better` + (`"higher"`/`"lower"`), and the `perfect` score: + + ```python + Protocol("mse", M.mse, representation="centroid", + positive="interpolated", negative="all_perturbed_mean", better="lower", perfect=0.0) + ``` + +The next section breaks these arguments down while building one up from scratch. + +## Create a protocol + +A protocol is two things: a pure metric **function** and a one-line **spec** that wires it +to data and scoring. We'll ease in — the simplest possible protocol first, then the spec +broken down, then a few richer examples. + +### Start simple + +Here is a complete new protocol: mean absolute error on the standard pseudobulk profiles. + +1. Add a pure function to [`src/scperteval/protocols/metrics.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/metrics.py): + + ```python + def mae(gt, prediction, ctx): + return float(np.mean(np.abs(gt - prediction))) + ``` + + Every metric function has this signature. `gt` is one perturbation's ground-truth + profile; `prediction` is the candidate being compared against it (under `calibrate`, scPertEval + calls the function once for the positive control and once for the negative; under `score`, once + with your model's prediction). `ctx` is the dataset context, needed by only a few metrics — + ignore it otherwise. Return a single number. + +2. Add a row to [`src/scperteval/protocols/table.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/table.py): + + ```python + Protocol("mae", M.mae, representation="centroid", + positive="interpolated", negative="all_perturbed_mean", + better="lower", perfect=0.0) + ``` + +Run it with `scperteval calibrate data.h5ad -p mae`. That is the whole protocol: MAE between each +perturbation's pseudobulk profile and its positive and negative controls, scored as +lower-is-better toward a perfect of 0. + +### The spec + +That row is the spec; parameters include: + +| argument | meaning | +|---|---| +| `name` | selects the protocol on the CLI (`-p mae`) | +| `representation` | the shape of each datapoint your function receives (see below) | +| `scope` | `"perturbation"` (default) or `"dataset"` — how many perturbations at once (see below) | +| `space` | which features to score — `full` (default), or a feature space like `top_50` | +| `centering` | a baseline subtracted before scoring, e.g. `"ctrl"` (default: none) | +| `positive` / `negative` | the two control sources to compare | +| `better` | `"higher"` or `"lower"` — which direction is an improvement | +| `perfect` | the value a flawless prediction attains | +| `param` | optional — a parameter family (`top_k`, `pca_k`, `degs_padj`, `overlap_k`) that makes the protocol tunable from the CLI; omit for a fixed protocol | + +**`representation`** decides the *shape* of each datapoint — the format `gt` and +`prediction` arrive in — so you never deal with sampling, references, or projection yourself: + +| `representation` | a datapoint is | +|---|---| +| `centroid` | a 1-D pseudobulk vector (one value per gene) | +| `population` | a `(cells × genes)` matrix | +| `de` | a `DEResult` (for the ground truth) / per-gene `\|score\|` ranking (for a prediction) | + +**`scope`** is the independent companion axis — *how many* perturbations the metric sees at once: + +| `scope` | the metric is called | +|---|---| +| `perturbation` (default) | once per perturbation — gets that perturbation's `(gt, prediction)` datapoints and returns a scalar | +| `dataset` | once for the whole dataset — gets the **list** of every perturbation's `gt` and `prediction` datapoints and returns one score per perturbation (e.g. a retrieval `rank`) | + +The two compose freely: `rank` is just `representation="centroid", scope="dataset"`; a +distributional retrieval metric would be `representation="population", scope="dataset"`. + +Many rows repeat the same wiring, so the top of `table.py` predefines the common +combinations as plain dicts. You then unpack one into a row with `**` (Python's +keyword-expansion syntax) to avoid retyping it: + +```python +_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") +_LOWER = dict(better="lower", perfect=0.0) +``` + +With those, the `mae` row above is exactly `Protocol("mae", M.mae, +representation="centroid", **_PB, **_LOWER)` — same protocol, less repetition. You'll see +these bundles reused throughout the table. + +### Building blocks — the palette + +The values those arguments take — feature spaces, control sources, DE methods, calibrators +— are registered building blocks. `scperteval list ` shows what's available +in each, with descriptions: + +**Feature spaces** (the `space` argument) + +```bash +$ scperteval list spaces +degs_0.05 — ground-truth DEGs at adjusted p < 0.05, per perturbation +full — all genes, no transform +pca_50 — top 50 principal components (fit on the dataset) +top_50 — top 50 genes by ground-truth effect size, per perturbation +``` + +`top_` / `pca_` / `degs_` are parameterised families (the defaults are shown); +a protocol template picks the value. If the space you need isn't here, see +[Add a feature space](building-blocks.md#add-a-feature-space). + +**DE methods** (the `--de-method` choice) + +```bash +$ scperteval list de-methods +MWU — Mann-Whitney U / Cliff's delta effect size (via illico) +t-test — Welch's t-test (default) — moment-based and fast +``` + +Chosen with `--de-method`; it applies to **every** DE-dependent unit (the `interpolated` +positive control, the `top_k`/`degs` spaces, the `de_*` protocols, and the WMSE weights). +To add another, see [Add a DE method](building-blocks.md#add-a-de-method). + +**Control sources** (the `positive` / `negative` arguments) + +```text +$ scperteval list sources +all_perturbed (cells) — all-perturbed reference sample, leave-one-out (single-cell negative control) +all_perturbed_mean (centroid) — all-perturbed mean, excluding the target — leave-one-out (pseudobulk sibling of all_perturbed; pseudobulk negative control) +control (cells) — non-targeting control cells +global_mean (centroid) — mean of all perturbations — shared baseline for the ranking protocols +gt_all_cells (cells) — ground truth — all of a perturbation's real cells (prediction-scoring truth) +gt_half (cells) — ground truth — the first half of a perturbation's cells (calibration truth) +interpolated (centroid) — interpolated duplicate — DE-weighted blend of the held-out half and the dataset mean (pseudobulk positive control) +prediction (cells) — model-predicted cells for the perturbation, from the --predictions h5ad +tech_dup (cells) — technical duplicate — the held-out second half (single-cell positive control) +``` + +Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative=` (or +`--positive`/`--negative`). The truth source is chosen by the command, not by a protocol: +`calibrate` uses `gt_half` (holding the other half out as the positive control), while `score` +uses `gt_all_cells` and compares it to `prediction`. To add another, see +[Add a control source](building-blocks.md#add-a-control-source). + +**Calibrators** (the `--output` choice) + +```bash +$ scperteval list calibrators +drf — Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025) +bds — Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026) +score — raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode) +``` + +`drf`/`bds` are chosen with `calibrate --output`; `score` is selected automatically by the +`score` command. To add another, see [Add a calibrator](building-blocks.md#add-a-calibrator). + +### More examples + +With the spec and the palette in hand, richer protocols are just different combinations. + +**Same wiring, different metric.** Cosine distance on pseudobulk reuses the bundles wholesale: + +```python +def cosine(gt, prediction, ctx): + return 1.0 - float(gt @ prediction / (np.linalg.norm(gt) * np.linalg.norm(prediction))) +``` + +```python +Protocol("cosine", M.cosine, representation="centroid", **_PB, **_LOWER) +``` + +**Restrict to a feature space.** Set `space` to score only some genes — e.g. MAE on the +top-50 DEGs: + +```python +Protocol("mae_top50", M.mae, representation="centroid", space="top_50", **_PB, **_LOWER) +``` + +**Expose the space as a knob (parameterised).** To make `k` adjustable per invocation, add a +`param` to the same `Protocol(...)` row — nothing else changes. The row's name carries the +parameter, and the value is supplied on the CLI: + +```python +Protocol("mae_top_k", M.mae, representation="centroid", param=top_k, **_PB, **_LOWER) +``` + +Then `scperteval calibrate data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The +families are `top_k` (top-k DEGs), `pca_k` (k PCs), and `degs_padj` (DEGs at adjusted +p < padj) for the space, and `overlap_k` to feed an integer straight to the metric. + +**A metric over cells, not profiles.** Switch `representation` to `population` and your +function receives `(cells × genes)` matrices; pair it with the single-cell controls: + +```python +def my_mmd(gt, prediction, ctx): # gt, prediction are (cells × genes) + ... +``` + +```python +Protocol("my_mmd_top50", M.my_mmd, representation="population", space="top_50", + positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) +``` + +This changes two pieces at once — the `representation` (so the function sees cells) and the controls +(the single-cell positive/negative) — which is the general pattern for a distributional +protocol. + +By now you've seen every moving part: the function, the spec, the building blocks the spec +draws on, fixed and parameterised spaces, and switching the representation the function +sees. Most new metrics are some combination of these. diff --git a/docs/user-guide/scoring.md b/docs/user-guide/scoring.md new file mode 100644 index 0000000..ecda9cb --- /dev/null +++ b/docs/user-guide/scoring.md @@ -0,0 +1,57 @@ +# Scoring predictions + +Scoring answers the question: **how well does a model's predicted response match the real +perturbation response?** Each evaluation protocol is applied to the predicted cells against the +real cells, yielding one score per perturbation. It is the conventional benchmarking step — the +number you report in a paper. + +This is distinct from [calibration](calibration.md), which asks a prior question: is a given +protocol trustworthy enough to report in the first place? Use calibration to select protocols, +then scoring to evaluate your model with them. + +`scperteval score dataset.h5ad predictions.h5ad` runs the *same* protocol catalog as +[`calibrate`](calibration.md); only two pieces differ. + +## Inputs + +- **ground truth** — *all* of a perturbation's real cells (the `gt_all_cells` source). Unlike + calibration, no half is held out and no positive/negative controls are built — the ground truth + is the whole real population. +- **prediction** — the matching cells from your `predictions.h5ad` (the `prediction` source). + The prediction file must contain the dataset's exact gene set (any order — columns are + reordered by name so the comparison lines up gene-for-gene) and the same perturbation labels. + A gene-set mismatch, or a perturbation present in the dataset but absent from the predictions, + raises an error naming exactly what's wrong. + +## Output + +The `score` calibrator reports each protocol's raw metric value per perturbation and its +mean/median across perturbations, written to `____score.csv`. Higher- vs +lower-is-better follows each protocol's `better` field, exactly as in calibration. + +```bash +scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc,unbiased_mmd_median_pca_k=20 +``` + +| protocol | representation | perfect prediction | degraded prediction | +|---|---|---|---| +| `pearson` | centroid | 1.000 | 0.993 | +| `mse` | centroid | 0.000 | 0.004 | +| `de_auprc` | de | 1.000 | 0.297 | +| `unbiased_mmd_median_pca_k=20` | population | ≈0 | 0.199 | + +(An exact replica of the real cells scores optimally; a prediction degraded toward the control +mean scores worse on every representation.) + +## How it relates to calibration + +Architecturally this reuses everything — the per-perturbation loop, every metric, representation, +and feature space are shared with `calibrate`. The only differences are the **truth source** +(`gt_all_cells` instead of the held-out `gt_half`) and the **calibrator** (`score`, which needs +only the prediction, instead of `drf`/`bds`, which need both controls). The DE-derived feature +spaces (`top_k`, `degs`) and the WMSE weights are computed from this same all-cells ground truth. + +Use `score` to measure how good a model's predictions are; use [`calibrate`](calibration.md) to +decide whether a given protocol is trustworthy enough to report those scores in the first place. + +→ See [Usage](usage.md) for the full CLI reference, all options, and `--help` output. diff --git a/docs/user-guide/usage.md b/docs/user-guide/usage.md new file mode 100644 index 0000000..38c56cc --- /dev/null +++ b/docs/user-guide/usage.md @@ -0,0 +1,156 @@ +# Usage + +## Input data + +scPertEval reads one preprocessed AnnData (`.h5ad`) per dataset. Only three things are required: + +- **`adata.X`** — normalized expression, cells × genes (e.g. `sc.pp.normalize_total` + `sc.pp.log1p`); sparse or dense float. +- **`adata.obs["perturbation"]`** — the perturbation label for each cell; control cells use the label `"control"`. Both names are configurable (`--perturbation-key` / `--control-label`). +- **`adata.var_names`** — gene identifiers, used as the DEG labels. + +Perturbations with at least `--min-cells` cells (default 30) are evaluated. Nothing else is +needed — references, DE, and PCA are all recomputed in memory, so no `uns`/`obsm`/`layers` are read. + +**Sample datasets.** Seven preprocessed perturbation datasets live in a public, read-only GCS +bucket and serve as a template for the format above: + +```bash +gsutil ls gs://scperteval/processed/ # wessels23, replogle22{k562,rpe1}, nadig25{hepg2,jurkat}, arch1, kaden25rpe1 +gsutil cp gs://scperteval/processed/wessels23_processed_complete.h5ad . +``` + +No gcloud account is needed — each file is also reachable over plain HTTPS at +`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. + +## Run it + +The same protocol catalog backs three commands: + +- **`calibrate`** — calibrate a protocol against built-in controls → DRF/BDS (see [Calibration](calibration.md)) +- **`score`** — score a model's predictions against ground truth (see [Scoring predictions](scoring.md)) +- **`de`** — export per-gene differential expression + +### Calibrate + +```bash +# protocols by name — including parameterised ones (set k / padj per protocol) +scperteval calibrate data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test + +# a parameterised protocol with no value uses its default (k=50, padj=0.05) +scperteval calibrate data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU + +# a whole group, or everything (parameterised protocols use their defaults) +scperteval calibrate data/wessels23.h5ad -p distributional --de-method MWU +scperteval calibrate data/wessels23.h5ad -p all --de-method t-test + +# DRF calibration only (compute DRF only; exclude BDS) +scperteval calibrate data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf +``` + +#### Output + +Prints a summary table and writes `____drf.csv` / `…__bds.csv` — raw control +values and calibrated DRF/BDS per perturbation. `--profile` adds a per-protocol wall-clock timing CSV. + +
scperteval calibrate --help + +```text +usage: scperteval calibrate [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] + [--subsample SUBSAMPLE] [--seed SEED] [--positive POSITIVE] + [--negative NEGATIVE] [--output {drf,bds}] [--out-dir OUT_DIR] + [--workers WORKERS] [--perturbation-key PERTURBATION_KEY] + [--control-label CONTROL_LABEL] [--min-cells MIN_CELLS] + [--profile] [--quiet] + dataset + + -p, --protocols comma-separated names (parameterised as name=value, e.g. + mse_top_k=30), a group (pseudobulk|distributional|de), or 'all' + --de-method {MWU, t-test, t-test_overestim_var} DE backend for every DE unit: + the interpolated positive control, the top_k/degs spaces, + the de_* protocols, and the WMSE weights + --subsample cells in the single-cell reference sample (default 8192) + --output {drf, bds} how per-perturbation values are calibrated + --positive/--negative override a protocol's controls by source name + --min-cells skip perturbations with fewer cells + --profile also write a per-protocol wall-clock timing table +``` + +
+ +### Score + +Score predictions against ground truth — predicted cells vs real cells, per protocol. +Predictions must have the same genes and perturbation labels as the dataset. + +```bash +scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc --de-method t-test +``` + +#### Output + +Prints a summary table and writes `____score.csv` — raw metric value per +perturbation. `--profile` adds a per-protocol wall-clock timing CSV. + +
scperteval score --help + +```text +usage: scperteval score [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] + [--subsample SUBSAMPLE] [--seed SEED] [--out-dir OUT_DIR] [--workers WORKERS] + [--perturbation-key PERTURBATION_KEY] [--control-label CONTROL_LABEL] + [--min-cells MIN_CELLS] [--profile] [--quiet] + dataset predictions + + dataset preprocessed .h5ad — the ground truth (real cells) + predictions predicted .h5ad — same genes and perturbation labels as the dataset + -p, --protocols comma-separated names, a group, or 'all' + --de-method DE backend for the de_* protocols, the top_k/degs spaces, and WMSE weights + --subsample cells in the all-perturbed reference (the ground truth is never subsampled) +``` + +Unlike `calibrate`, there are no `--positive`/`--negative`/`--output` options: the candidate is +always your prediction and the output is always the raw `score`. + +
+ +### DE + +Export per-gene differential expression to HDF5 — provided as a convenience since DE methods are +tightly coupled with some evaluation protocols. + +```bash +scperteval de data/wessels23.h5ad --methods MWU +``` + +### Discover what's available + +```bash +scperteval list protocols # also: de-methods | spaces | sources | calibrators +``` + +**DE backends** (`scperteval list de-methods`): + +- `t-test` (default, Welch's, moment-based) +- `MWU` (Cliff's δ via illico) +- `t-test_overestim_var` (scanpy's conservative-variance variant — the reference variance is scaled by the target's cell count). + +Select one with `--de-method` for a `calibrate`/`score`, or list several with `--methods` for a `de` export. The +overestim variant is a selectable backend for new protocols; no current protocol uses it. + +## Use it from Python + +Install with `pip install scperteval` (or, from this repo, +`pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git"`). +The simplest path mirrors the CLI — call it via subprocess, exactly as the figure notebook does: + +```python +import subprocess, sys + +subprocess.run([sys.executable, "-m", "scperteval", "calibrate", "data/wessels23.h5ad", + "-p", "all", "--de-method", "t-test", "--out-dir", "results"], check=True) +# -> results/wessels23____drf.csv (raw control values + calibrated DRF per perturbation) + +# score predictions against ground truth instead: +subprocess.run([sys.executable, "-m", "scperteval", "score", "data/wessels23.h5ad", + "predictions.h5ad", "-p", "all", "--out-dir", "results"], check=True) +# -> results/wessels23____score.csv (raw metric value per perturbation) +``` diff --git a/pyproject.toml b/pyproject.toml index 2d0924d..a2bc892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,30 +1,121 @@ +[build-system] +build-backend = "hatchling.build" +requires = [ "hatchling" ] + [project] name = "scperteval" version = "0.1.0" description = "Evaluation Protocols for Perturbation Studies: per-metric DRF/BDS calibration on a single preprocessed dataset." -requires-python = ">=3.10" +readme = "README.md" +license = { file = "LICENSE" } +authors = [ + { name = "Zach Boldyga" }, +] +requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", +] dependencies = [ - "anndata", - "scanpy", - "numpy", - "scipy", - "scikit-learn", - "pandas", - "geomloss", - "torch", - "illico", - "h5py", + "anndata", + "geomloss", + "h5py", + "illico", + "numpy", + "pandas", + "scanpy", + "scikit-learn", + "scipy", + "torch", +] +urls.Documentation = "https://scperteval.readthedocs.io/" +urls.Homepage = "https://github.com/Virtual-Cell-Research-Community/scPertEval" +urls.Source = "https://github.com/Virtual-Cell-Research-Community/scPertEval" +scripts.scperteval = "scperteval.cli:main" + +[dependency-groups] +dev = [ + { include-group = "docs" }, + { include-group = "lint" }, + { include-group = "test" }, +] +docs = [ + "ipykernel", + "ipython", + "myst-nb>=1.1", + "pandas", + "sphinx>=8.1", + "sphinx-autobuild>=2025.8.25", + "sphinx-autodoc-typehints", + "sphinx-book-theme>=1", + "sphinx-copybutton", + "sphinx-design", + "sphinxcontrib-bibtex>=1", + "sphinxcontrib-katex", + "sphinxext-opengraph", +] +lint = [ + "mypy", + "pyright", + "ruff", +] +test = [ + "pytest", ] -[project.scripts] -scperteval = "scperteval.cli:main" +[tool.hatch] +build.targets.wheel.packages = [ "src/scperteval" ] -[project.optional-dependencies] -test = ["pytest"] +[tool.ruff] +target-version = "py311" +line-length = 120 +src = [ "src" ] +extend-include = [ "*.ipynb" ] +format.docstring-code-format = true +lint.select = [ + "B", # flake8-bugbear — common bug patterns (mutable defaults, unused loop vars…) + "BLE", # flake8-blind-except — catch bare `except Exception` + "C4", # flake8-comprehensions — idiomatic list/dict/set comprehensions + "D", # pydocstyle — numpy docstring convention + "E", # pycodestyle errors (E722 bare except, E731 lambda assign, E741 ambiguous names…) + "F", # Pyflakes (unused imports F401, undefined names, unused variables…) + "I", # isort — consistent import ordering + "PTH", # flake8-use-pathlib — replace os.path.* and open() with pathlib equivalents + "RUF", # Ruff-specific — list unpacking, quadratic sum, unused unpacked vars… + "SIM", # flake8-simplify — simplify if/else, dict.keys() iteration… + "UP", # pyupgrade — modern Python syntax (UP006 built-in generics, UP007 x | None…) + "W", # pycodestyle warnings (W291 trailing whitespace, W605 invalid escape sequence…) +] +lint.ignore = [ + "B905", # `zip()` without an explicit `strict=` parameter + "C408", # Unnecessary dict(), list() or tuple() calls that can be rewritten as empty literals. + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic methods + "D107", # Missing docstring in __init__ + "D203", # No blank line before class docstring (incompatible with D211) + "D213", # Multi-line summary second line (incompatible with D212) + "D400", # First line should end with period (breaks single-line docstrings) + "D401", # First line in imperative mood + "E501", # line too long — formatter enforces line-length = 120, remaining cases are unfixable strings/comments + "RUF001", # ambiguous unicode in strings — intentional (×, –, ➕ used in UI labels and math) + "RUF002", # ambiguous unicode in docstrings — same rationale + "RUF003", # ambiguous unicode in comments — same rationale +] +lint.per-file-ignores."*/__init__.py" = [ "F401" ] +lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."tests/*" = [ "D" ] +lint.pydocstyle.convention = "numpy" -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +[tool.mypy] +mypy_path = [ "src" ] +exclude = [ "^docs/" ] +ignore_missing_imports = true +check_untyped_defs = true +warn_unreachable = true -[tool.hatch.build.targets.wheel] -packages = ["scperteval"] +[tool.pyright] +exclude = [ "**/.*", "**/__pycache__", "**/node_modules", ".venv", "docs/**" ] diff --git a/scperteval/__init__.py b/scperteval/__init__.py deleted file mode 100644 index a320ab4..0000000 --- a/scperteval/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Evaluation Protocols for Perturbation Studies.""" -import os as _os - -for _v in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", - "NUMEXPR_NUM_THREADS", "VECLIB_MAXIMUM_THREADS"): - _os.environ.setdefault(_v, "1") diff --git a/scperteval/blocks/__init__.py b/scperteval/blocks/__init__.py deleted file mode 100644 index 404db76..0000000 --- a/scperteval/blocks/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""Pluggable building blocks: DE methods and feature spaces.""" -from . import de, spaces # noqa: F401 (import for registration side effects) diff --git a/scperteval/blocks/spaces.py b/scperteval/blocks/spaces.py deleted file mode 100644 index 96f8ab7..0000000 --- a/scperteval/blocks/spaces.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Feature spaces: a transform applied to the gene axis before a protocol runs. - -Spaces receive the raw (possibly sparse) cells and return a dense array, so a -gene-subset space densifies only its subset. The parameterised families -``top_`` / ``pca_`` / ``degs_`` are registered on demand by the -``top_space`` / ``pca_space`` / ``degs_space`` factories (used by the protocol -templates); the default instances created at import are what ``scperteval list spaces`` -shows. ``description`` is shown by ``scperteval list spaces``. -""" -from __future__ import annotations - -import numpy as np - -from ..dataset import to_dense -from ..registry import Registry - -SPACES = Registry("space") - - -@SPACES.register("full", global_space=True, description="all genes, no transform") -def space_full(X, ctx, pert): - return to_dense(X) - - -def _field(de, name): - return de.extra[name.split(":", 1)[1]] if name.startswith("extra:") else getattr(de, name) - - -def register_de_space(name, field, top=None, threshold=None, description=""): - """Register a DE-derived gene subset selected from a field of the GT DEResult.""" - - def space(X, ctx, pert): - values = _field(ctx.de(pert, ctx.cfg.truth), field) - keep = np.argsort(-np.abs(values))[:top] if top is not None else np.where(threshold(values))[0] - return to_dense(X[:, keep]) - - SPACES.add(name, space, description=description) - return name - - -def top_space(k: int) -> str: - """top-k genes by |ground-truth effect size| (registered on demand).""" - name = f"top_{k}" - if name not in SPACES: - register_de_space(name, field="score", top=k, - description=f"top {k} genes by ground-truth effect size, per perturbation") - return name - - -def degs_space(padj: float) -> str: - """ground-truth DEGs at adjusted p < padj (registered on demand).""" - name = f"degs_{padj:g}" - if name not in SPACES: - register_de_space(name, field="pvalue_adj", threshold=(lambda v, p=padj: v < p), - description=f"ground-truth DEGs at adjusted p < {padj:g}, per perturbation") - return name - - -def pca_space(k: int) -> str: - """top-k principal components (registered on demand).""" - name = f"pca_{k}" - if name not in SPACES: - SPACES.add(name, lambda X, ctx, pert, k=k: ctx.pca(k).transform(to_dense(X))[:, :k], - global_space=True, description=f"top {k} principal components (fit on the dataset)") - return name - - -# Default instances — also what `scperteval list spaces` shows. -top_space(50) -pca_space(50) -degs_space(0.05) diff --git a/scperteval/protocols/metrics.py b/scperteval/protocols/metrics.py deleted file mode 100644 index 58515e7..0000000 --- a/scperteval/protocols/metrics.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Evaluation-protocol metrics — the exact implementation of every metric. - -A metric takes the ground-truth and a prediction (whichever control is being scored) plus -the context, and returns a score. The protocol's ``representation`` sets each datapoint's -shape — ``centroid`` -> a 1-D pseudobulk vector, ``population`` -> a (cells x genes) array, -``de`` -> a DEResult (GT) / |score| ranking (prediction). Its ``scope`` sets the call: a -``perturbation``-scope metric gets one perturbation's (gt, prediction) and returns a scalar; -a ``dataset``-scope metric gets the list of every perturbation's gt and prediction and -returns one score per perturbation (e.g. ``rank_retrieval``). - -Every metric is implemented in full here; only external numerical libraries (numpy, -scikit-learn, geomloss) are relied upon. So a metric is completely defined by its function -below plus its row in ``table.py`` — nothing is hidden behind another layer. -""" -from __future__ import annotations - -import numpy as np -from sklearn.metrics import average_precision_score, roc_auc_score - - -def _sq_dists(X, Y): - """Pairwise squared euclidean distances via ||x||^2 + ||y||^2 - 2 x.y. - - Routed through a BLAS matrix product, which releases the GIL so the - per-perturbation thread pool actually parallelises. - """ - xx = np.einsum("ij,ij->i", X, X) - yy = np.einsum("ij,ij->i", Y, Y) - sq = xx[:, None] + yy[None, :] - 2.0 * (X @ Y.T) - return np.maximum(sq, 0.0) - - -def _within_unbiased(sq, n): - """Unbiased (U-statistic) mean within-population euclidean distance.""" - if n <= 1: - return 0.0 - return float(np.sqrt(sq).sum() / (n * (n - 1))) - - -def pearson(gt, prediction, ctx): - return float(np.corrcoef(gt, prediction)[0, 1]) - - -def mse(gt, prediction, ctx): - return float(np.mean((gt - prediction) ** 2)) - - -def weighted_mse(gt, prediction, ctx, exp=2.0): - w = ctx.wmse_weights(ctx.current_pert) ** exp - total = w.sum() - w = w / total if total > 0 else np.full(w.size, 1.0 / w.size) - return float(np.sum(w * (gt - prediction) ** 2)) - - -def energy_distance(gt, prediction, ctx): - """Szekely-Rizzo energy distance with bias-corrected within terms.""" - if len(gt) == 0 or len(prediction) == 0: - return float("nan") - X = gt.astype(np.float64) - Y = prediction.astype(np.float64) - cross = np.sqrt(_sq_dists(X, Y)).mean() - xx = _within_unbiased(_sq_dists(X, X), len(X)) - yy = _within_unbiased(_sq_dists(Y, Y), len(Y)) - return float(2.0 * cross - xx - yy) - - -def unbiased_mmd_median(gt, prediction, ctx): - """Unbiased RBF-MMD^2 with a single median-heuristic bandwidth (Gretton 2012).""" - if len(gt) < 2 or len(prediction) < 2: - return float("nan") - X = gt.astype(np.float64) - Y = prediction.astype(np.float64) - nx, ny = len(X), len(Y) - pooled = np.vstack([X, Y]) - euc = np.sqrt(_sq_dists(pooled, pooled)) - n = euc.shape[0] - sigma = float(np.median(euc[~np.eye(n, dtype=bool)])) - if sigma <= 0: - return 0.0 - gamma = 1.0 / (2.0 * sigma * sigma) - k_xx = np.exp(-gamma * _sq_dists(X, X)) - k_yy = np.exp(-gamma * _sq_dists(Y, Y)) - k_xy = np.exp(-gamma * _sq_dists(X, Y)) - xx = (k_xx.sum() - np.trace(k_xx)) / (nx * (nx - 1)) - yy = (k_yy.sum() - np.trace(k_yy)) / (ny * (ny - 1)) - return float(xx + yy - 2.0 * k_xy.mean()) - - -_geomloss_cache: dict = {} - - -def sinkhorn_w2(gt, prediction, ctx, blur=0.05): - """Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2).""" - if len(gt) == 0 or len(prediction) == 0: - return float("nan") - import torch - from geomloss import SamplesLoss - - loss = _geomloss_cache.get(blur) - if loss is None: - torch.set_num_threads(1) - loss = SamplesLoss(loss="sinkhorn", p=2, blur=blur, debias=True, backend="tensorized") - _geomloss_cache[blur] = loss - Xt = torch.as_tensor(np.ascontiguousarray(gt), dtype=torch.float32) - Yt = torch.as_tensor(np.ascontiguousarray(prediction), dtype=torch.float32) - a = torch.full((len(gt),), 1.0 / len(gt), dtype=torch.float32) - b = torch.full((len(prediction),), 1.0 / len(prediction), dtype=torch.float32) - with torch.no_grad(): - val = float(loss(a, Xt, b, Yt)) - return float(np.sqrt(max(2.0 * val, 0.0))) - - -def rank_retrieval(gt, prediction, ctx, transpose=False): - """Cross-perturbation retrieval rank (0 = best, lower is better) — a dataset-scope metric. - - ``gt`` and ``prediction`` are the lists of every perturbation's centroid (one per - perturbation); this returns one score per perturbation. In the prediction-vs-GT - squared-distance matrix, ``rank`` ranks each GT's own prediction against all predictions - (column-wise); ``transpose_rank`` ranks each prediction's own GT against all GTs - (row-wise). Normalised by n-1, with the drf tie-breaking noise (seed 42). - """ - G = np.vstack(gt) - P = np.vstack(prediction) - sq = _sq_dists(P, G) - if transpose: - sq = sq.T - n = sq.shape[0] - noise = np.random.default_rng(42).uniform(0, 1e-12, size=sq.shape) - ranks = np.argsort(np.argsort(sq + noise, axis=0), axis=0) - return np.diag(ranks).astype(np.float64) / max(n - 1, 1) - - -def de_auprc(gt, prediction, ctx): - labels = gt.pvalue_adj < 0.05 - if labels.sum() == 0 or labels.sum() == labels.size: - return float("nan") - return float(average_precision_score(labels, prediction)) - - -def de_auroc(gt, prediction, ctx): - labels = gt.pvalue_adj < 0.05 - if labels.sum() == 0 or labels.sum() == labels.size: - return float("nan") - return float(roc_auc_score(labels, prediction)) - - -def de_overlap(gt, prediction, ctx, k=50): - truth = np.abs(gt.score) - if k >= truth.size: - return float("nan") - top_truth = np.argpartition(-truth, k - 1)[:k] - top_prediction = np.argpartition(-prediction, k - 1)[:k] - return float(np.intersect1d(top_truth, top_prediction).size) / k diff --git a/scperteval/sources.py b/scperteval/sources.py deleted file mode 100644 index 93ffd70..0000000 --- a/scperteval/sources.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Control/reference sources: each yields a perturbation's cells or pseudobulk centroid. - -A source's positive/negative role is chosen at the CLI; the registry just produces -the data. ``provides`` ("cells" or "centroid") drives the runner's compatibility -check and how the context turns a source into a view. ``description`` is shown by -``scperteval list sources``. -""" -from __future__ import annotations - -import numpy as np - -from .dataset import to_dense -from .registry import Registry - -SOURCES = Registry("source") - - -@SOURCES.register("gt_half", provides="cells", - description="ground truth — the first half of a perturbation's cells (calibration truth)") -def src_gt_half(ctx, pert): - return ctx.ds.cells(pert, half="first") - - -@SOURCES.register("gt_all_cells", provides="cells", - description="ground truth — all of a perturbation's real cells (prediction-scoring truth)") -def src_gt_all_cells(ctx, pert): - return ctx.ds.cells(pert) - - -@SOURCES.register("prediction", provides="cells", - description="model-predicted cells for the perturbation, from the --predictions h5ad") -def src_prediction(ctx, pert): - return ctx.predictions.cells(pert) - - -@SOURCES.register("tech_dup", provides="cells", - description="technical duplicate — the held-out second half (single-cell positive control)") -def src_tech_dup(ctx, pert): - return ctx.ds.cells(pert, half="second") - - -@SOURCES.register("control", provides="cells", - description="non-targeting control cells") -def src_control(ctx, pert): - return ctx.ds.control_cells(ctx.cfg.subsample) - - -@SOURCES.register("all_perturbed", provides="cells", - description="all-perturbed reference sample, leave-one-out (single-cell negative control)") -def src_all_perturbed(ctx, pert): - return ctx.reference().subset(pert) - - -@SOURCES.register("all_perturbed_mean", provides="centroid", - description="all-perturbed mean, excluding the target — leave-one-out " - "(pseudobulk sibling of all_perturbed; pseudobulk negative control)") -def src_all_perturbed_mean(ctx, pert): - return ctx.ds.allpert_mean_except(pert) - - -@SOURCES.register("global_mean", provides="centroid", - description="mean of all perturbations — shared baseline for the ranking protocols") -def src_global_mean(ctx, pert): - return ctx.ds.allpert_mean() - - -@SOURCES.register("interpolated", provides="centroid", - description="interpolated duplicate — DE-weighted blend of the held-out half and " - "the dataset mean (pseudobulk positive control)") -def src_interpolated(ctx, pert): - """alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward - the held-out replicate where the gene is significant, else toward the all-perturbed mean.""" - tech = np.asarray(to_dense(ctx.ds.cells(pert, half="second"))).mean(0) - alpha = np.nan_to_num(1.0 - ctx.de(pert, "tech_dup", "control").pvalue_adj, nan=0.0) - return alpha * tech + (1.0 - alpha) * ctx.ds.allpert_mean_except(pert) diff --git a/scperteval/types.py b/scperteval/types.py deleted file mode 100644 index ecb7bae..0000000 --- a/scperteval/types.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Core dataclasses shared across the package.""" -from __future__ import annotations - -from dataclasses import dataclass, field, replace -from functools import partial -from typing import Callable, Optional - -import numpy as np - - -@dataclass -class RunConfig: - """Resolved options for a single run.""" - - dataset: str - protocols: list[str] - de_method: str = "t-test" - subsample: int = 8192 - seed: int = 42 - positive: str = "auto" - negative: str = "auto" - truth: str = "gt_half" - predictions: Optional[str] = None - output: str = "drf" - out_dir: str = "results" - workers: int = 0 - min_cells: int = 30 - perturbation_key: str = "perturbation" - control_label: str = "control" - profile: bool = False - - -@dataclass(frozen=True) -class DEResult: - """Per-gene differential expression for one target-vs-reference comparison.""" - - score: np.ndarray - pvalue: np.ndarray - pvalue_adj: np.ndarray - extra: dict = field(default_factory=dict) - - -@dataclass(frozen=True) -class Param: - """A protocol's tunable knob — how a CLI value (``k=30``, ``padj=0.05``) is cast, - defaulted, and applied. ``space`` maps the value to a feature-space name; when it is - ``None`` the value is passed straight to the metric as a keyword argument. - """ - - name: str - cast: Callable - default: float - space: Optional[Callable] = None - - -@dataclass(frozen=True) -class Protocol: - """An evaluation protocol: a pure metric plus its data and control wiring. - - ``representation`` and ``scope`` are independent and together decide what the metric - receives: - - - ``representation`` — the shape of one perturbation's datapoint: ``"centroid"`` (a 1-D - pseudobulk vector), ``"population"`` (a cells × genes matrix), or ``"de"`` (a DEResult). - - ``scope`` — ``"perturbation"`` (default): the metric is called once per perturbation, - gets that perturbation's ``(gt, prediction)`` datapoints, and returns a scalar. - ``"dataset"``: the metric is called once, gets the list of *every* perturbation's - ``gt`` and ``prediction`` datapoints, and returns one score per perturbation (e.g. a - cross-perturbation retrieval rank). - - Set ``param`` to make the protocol *tunable* — its feature space (or a metric argument) - is then chosen per run from a CLI value, e.g. ``-p mse_top_k=30``; with no value the - parameter's default is used. Leave ``param`` unset for a fully-specified protocol. - - ``better`` and ``perfect`` describe the metric's score scale and are independent: - - - ``better`` — which direction is an improvement, ``"higher"`` or ``"lower"``. - Correlations and overlaps improve as they go up (``"higher"``); errors and - distances improve as they go down (``"lower"``). This is the metric's *sense*, - and it is not implied by ``perfect`` — e.g. perplexity has ``perfect=1.0`` yet - ``better="lower"``, and a log-likelihood has ``perfect=0.0`` yet ``better="higher"``. - - ``perfect`` — the value a flawless prediction attains (1.0 for a correlation, - 0.0 for an error). It anchors the top of the DRF scale. - - Together they let the calibrator orient the score: DRF measures how far the positive - control moves from the negative-control floor toward ``perfect`` in the ``better`` - direction, and BDS counts the perturbations where the positive control is ``better`` - than the negative. - """ - - name: str - metric: Callable - representation: str - scope: str = "perturbation" - space: str = "full" - centering: Optional[str] = None - reference: str = "all_perturbed" - neg_reference: Optional[str] = None - better: str = "higher" - perfect: float = 1.0 - positive: str = "auto" - negative: str = "auto" - group: str = "" - param: Optional[Param] = None - - @property - def parameterised(self) -> bool: - return self.param is not None - - def resolve(self, value) -> "Protocol": - """Concrete protocol for a tunable one at ``value`` (sets the space or metric arg).""" - suffix = f"{value:g}" if isinstance(value, float) else str(value) - name = f"{self.name}={suffix}" - if self.param.space is not None: - return replace(self, name=name, space=self.param.space(value), param=None) - metric = partial(self.metric, **{self.param.name: value}) - return replace(self, name=name, metric=metric, param=None) - - -@dataclass(frozen=True) -class Calibrator: - """Turns per-control raw metric values into per-perturbation and aggregate scores.""" - - name: str - requires: tuple[str, ...] - per_pert: Callable - aggregate: Callable - description: str = "" diff --git a/src/scperteval/__init__.py b/src/scperteval/__init__.py new file mode 100644 index 0000000..2a196bc --- /dev/null +++ b/src/scperteval/__init__.py @@ -0,0 +1,12 @@ +"""Evaluation Protocols for Perturbation Studies.""" + +import os as _os + +for _v in ( + "OMP_NUM_THREADS", + "OPENBLAS_NUM_THREADS", + "MKL_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "VECLIB_MAXIMUM_THREADS", +): + _os.environ.setdefault(_v, "1") diff --git a/scperteval/__main__.py b/src/scperteval/__main__.py similarity index 100% rename from scperteval/__main__.py rename to src/scperteval/__main__.py diff --git a/src/scperteval/blocks/__init__.py b/src/scperteval/blocks/__init__.py new file mode 100644 index 0000000..162dacd --- /dev/null +++ b/src/scperteval/blocks/__init__.py @@ -0,0 +1,3 @@ +"""Pluggable building blocks: DE methods and feature spaces.""" + +from . import de, spaces diff --git a/scperteval/blocks/de.py b/src/scperteval/blocks/de.py similarity index 52% rename from scperteval/blocks/de.py rename to src/scperteval/blocks/de.py index f4518ed..028cb41 100644 --- a/scperteval/blocks/de.py +++ b/src/scperteval/blocks/de.py @@ -4,6 +4,7 @@ expressed through reusable ``moments`` / ``ttest_from_moments`` helpers so the context can cache a shared reference's moments and combine them cheaply. """ + from __future__ import annotations import numpy as np @@ -14,10 +15,43 @@ from ..types import DEResult DE_METHODS = Registry("de-method") +"""Registry of DE backends; keys are the ``--de-method`` names. + +Use :meth:`~scperteval.registry.Registry.register` to add a custom backend:: + + from scperteval.blocks.de import DE_METHODS, bh + from scperteval.types import DEResult + + @DE_METHODS.register("my_test", description="My custom DE test") + def de_my_test(target, reference): + score = ... # per-gene statistic, shape (G,) + pvalue = ... # per-gene raw p-values, shape (G,) + return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) + +Then ``--de-method my_test`` routes every DE-dependent unit through it. +""" def moments(X): - """Per-gene (mean, sample variance, n) for a cell matrix, sparse- or dense-aware.""" + r"""Per-gene mean, sample variance, and cell count for a cell matrix. + + Sparse- and dense-aware; uses :math:`\text{Var}(X) = E[X^2] - E[X]^2` + with Bessel's correction. + + Parameters + ---------- + X : array-like, shape ``(n, G)`` + Cell matrix (sparse or dense). + + Returns + ------- + mean : numpy.ndarray, shape ``(G,)`` + Per-gene sample mean. + variance : numpy.ndarray, shape ``(G,)`` + Per-gene sample variance (floored at 0). + n : int + Number of cells. + """ n = X.shape[0] if sp.issparse(X): m = np.asarray(X.mean(0)).ravel() @@ -31,7 +65,22 @@ def moments(X): def bh(pvalue: np.ndarray) -> np.ndarray: - """Benjamini-Hochberg adjusted p-values.""" + """Benjamini-Hochberg adjusted p-values for FDR control :cite:p:`Vollenweider_2026`. + + Applied gene-wise inside each DE method to control the false discovery rate + across genes. The same procedure is used across perturbations to summarise + the overall sensitivity of a metric — see :cite:t:`Vollenweider_2026`. + + Parameters + ---------- + pvalue : numpy.ndarray + Array of raw p-values; non-finite values are carried through as ``nan``. + + Returns + ------- + numpy.ndarray + BH-adjusted p-values clipped to [0, 1]; same shape as ``pvalue``. + """ p = np.asarray(pvalue, dtype=np.float64) out = np.full(p.shape, np.nan) idx = np.where(np.isfinite(p))[0] @@ -44,11 +93,36 @@ def bh(pvalue: np.ndarray) -> np.ndarray: def ttest_from_moments(mt, vt, nt, mr, vr, nr) -> DEResult: - """Welch's t-test (scanpy convention); score = t-statistic.""" + """Welch's t-test from pre-computed per-gene moments (scanpy convention). + + Accepts moments directly so the context can cache the reference's moments + once and combine them cheaply for every perturbation. The ``score`` field + of the returned :class:`~scperteval.types.DEResult` is the t-statistic. + + Parameters + ---------- + mt : numpy.ndarray, shape ``(G,)`` + Target per-gene means. + vt : numpy.ndarray, shape ``(G,)`` + Target per-gene sample variances. + nt : int + Number of target cells. + mr : numpy.ndarray, shape ``(G,)`` + Reference per-gene means. + vr : numpy.ndarray, shape ``(G,)`` + Reference per-gene sample variances. + nr : int + Number of reference cells. + + Returns + ------- + ~scperteval.types.DEResult + ``score`` is the Welch t-statistic; ``pvalue_adj`` is BH-adjusted. + """ se2 = vt / nt + vr / nr with np.errstate(divide="ignore", invalid="ignore"): t = (mt - mr) / np.sqrt(se2) - df = se2 ** 2 / ((vt / nt) ** 2 / max(nt - 1, 1) + (vr / nr) ** 2 / max(nr - 1, 1)) + df = se2**2 / ((vt / nt) ** 2 / max(nt - 1, 1) + (vr / nr) ** 2 / max(nr - 1, 1)) t = np.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0) df = np.where(np.isfinite(df) & (df > 0), df, 1.0) pval = np.nan_to_num(2.0 * stats.t.sf(np.abs(t), df), nan=1.0) @@ -57,22 +131,22 @@ def ttest_from_moments(mt, vt, nt, mr, vr, nr) -> DEResult: @DE_METHODS.register("t-test", description="Welch's t-test (default) — moment-based and fast") def de_ttest(target, reference) -> DEResult: + """Welch's t-test between target and reference cell matrices.""" return ttest_from_moments(*moments(target), *moments(reference)) @DE_METHODS.register( "t-test_overestim_var", description="scanpy's conservative t-test variant; reference variance scaled by the " - "target's cell count (selectable backend; not used by any current protocol)", + "target's cell count (selectable backend; not used by any current protocol)", ) def de_ttest_overestim(target, reference) -> DEResult: - """scanpy ``rank_genes_groups(method='t-test_overestim_var')``. + """Scanpy ``rank_genes_groups(method='t-test_overestim_var')``. Identical to Welch's t-test except the reference group's cell count is replaced by the target's, which inflates the reference standard-error term ("overestimating" its variance for small target groups) and yields a more conservative statistic. Selectable as a DE - backend (``--de-method``/``--methods``) so new evaluation protocols can use it; no current - protocol does. + backend (``--de-method``/``--methods``); no current protocol uses it. """ mt, vt, nt = moments(target) mr, vr, _nr = moments(reference) @@ -97,10 +171,18 @@ def de_mwu(target, reference) -> DEResult: adata = ad.AnnData(np.vstack([Xt, Xr]).astype(np.float64)) adata.var_names = genes adata.obs["_g"] = ["target"] * nt + ["reference"] * nr - df = asymptotic_wilcoxon(adata, is_log1p=True, group_keys="_g", reference="reference", - n_threads=1, alternative="two-sided", use_continuity=True, - tie_correct=True, return_as_scanpy=False) - sub = df.xs("target", level=0).reindex(genes) + df = asymptotic_wilcoxon( + adata, + is_log1p=True, + group_keys="_g", + reference="reference", + n_threads=1, + alternative="two-sided", + use_continuity=True, + tie_correct=True, + return_as_scanpy=False, + ) + sub = df.xs("target", level=0).reindex(genes) # pyright: ignore[reportAttributeAccessIssue] u = sub["statistic"].to_numpy(dtype=np.float64) pval = np.nan_to_num(sub["p_value"].to_numpy(dtype=np.float64), nan=1.0) cliff = 2.0 * u / (nt * nr) - 1.0 diff --git a/src/scperteval/blocks/spaces.py b/src/scperteval/blocks/spaces.py new file mode 100644 index 0000000..dbefd27 --- /dev/null +++ b/src/scperteval/blocks/spaces.py @@ -0,0 +1,159 @@ +"""Feature spaces: a transform applied to the gene axis before a protocol runs. + +Spaces receive the raw (possibly sparse) cells and return a dense array, so a +gene-subset space densifies only its subset. The parameterised families +``top_`` / ``pca_`` / ``degs_`` are registered on demand by the +``top_space`` / ``pca_space`` / ``degs_space`` factories (used by the protocol +templates); the default instances created at import are what ``scperteval list spaces`` +shows. ``description`` is shown by ``scperteval list spaces``. +""" + +from __future__ import annotations + +import numpy as np + +from ..dataset import to_dense +from ..registry import Registry + +SPACES = Registry("space") +"""Registry of feature-space transforms; keys are space names (e.g. ``"top_50"``). + +Use :meth:`~scperteval.registry.Registry.register` to add a custom space:: + + from scperteval.blocks.spaces import SPACES, to_dense + + @SPACES.register("hvg_100", global_space=True, description="100 highest-variance genes") + def space_hvg(X, ctx, pert): + keep = ... # indices of the 100 genes to keep + return to_dense(X[:, keep]) + +Pass ``global_space=True`` if the transform does not depend on the perturbation +(so it can be computed once and shared across all perturbations in a run). +""" + + +@SPACES.register("full", global_space=True, description="all genes, no transform") +def space_full(X, ctx, pert): + """Identity space: all genes, densified, no transform.""" + return to_dense(X) + + +def _field(de, name): + return de.extra[name.split(":", 1)[1]] if name.startswith("extra:") else getattr(de, name) + + +def register_de_space(name, field, top=None, threshold=None, description=""): + r"""Register a DE-derived gene subset selected from a field of the GT DEResult. + + Exactly one of ``top`` (select top-k by \|value\|) or ``threshold`` (a callable + returning a boolean mask) must be provided. + + Parameters + ---------- + name : str + Registry key for the new space. + field : str + Attribute of :class:`~scperteval.types.DEResult` to read + (e.g. ``"score"``, ``"pvalue_adj"``). + top : int or None + If given, keep the top-k genes by absolute value of ``field``. + threshold : Callable or None + If given, a function ``(values) -> bool mask`` selecting genes to keep. + description : str + Human-readable description shown by ``scperteval list spaces``. + + Returns + ------- + str + The registered space name (same as ``name``). + """ + + def space(X, ctx, pert): + values = _field(ctx.de(pert, ctx.cfg.truth), field) + if top is not None: + keep = np.argsort(-np.abs(values))[:top] + else: + assert threshold is not None # register_de_space takes exactly one of top/threshold + keep = np.where(threshold(values))[0] + return to_dense(X[:, keep]) + + SPACES.add(name, space, description=description) + return name + + +def top_space(k: int) -> str: + r"""top-k genes by absolute ground-truth effect size (registered on demand). + + Parameters + ---------- + k : int + Number of genes to keep (selected by \|ground-truth effect size\| per perturbation). + + Returns + ------- + str + Space name ``"top_"`` (e.g. ``"top_50"``). + """ + name = f"top_{k}" + if name not in SPACES: + register_de_space( + name, field="score", top=k, description=f"top {k} genes by ground-truth effect size, per perturbation" + ) + return name + + +def degs_space(padj: float) -> str: + """ground-truth DEGs at adjusted p < padj (registered on demand). + + Parameters + ---------- + padj : float + Adjusted p-value threshold (e.g. 0.05). + + Returns + ------- + str + Space name ``"degs_"`` (e.g. ``"degs_0.05"``). + """ + name = f"degs_{padj:g}" + if name not in SPACES: + register_de_space( + name, + field="pvalue_adj", + threshold=(lambda v, p=padj: v < p), + description=f"ground-truth DEGs at adjusted p < {padj:g}, per perturbation", + ) + return name + + +def pca_space(k: int) -> str: + """top-k principal components (registered on demand). + + PCA is fit once on (up to 50 000) cells from the full dataset, then applied + to each cell population. The fitted transform is shared across perturbations. + + Parameters + ---------- + k : int + Number of principal components to retain. + + Returns + ------- + str + Space name ``"pca_"`` (e.g. ``"pca_50"``). + """ + name = f"pca_{k}" + if name not in SPACES: + SPACES.add( + name, + lambda X, ctx, pert, k=k: ctx.pca(k).transform(to_dense(X))[:, :k], + global_space=True, + description=f"top {k} principal components (fit on the dataset)", + ) + return name + + +# Default instances — also what `scperteval list spaces` shows. +top_space(50) +pca_space(50) +degs_space(0.05) diff --git a/scperteval/calibrators.py b/src/scperteval/calibrators.py similarity index 76% rename from scperteval/calibrators.py rename to src/scperteval/calibrators.py index 1a1ce00..c7c9298 100644 --- a/scperteval/calibrators.py +++ b/src/scperteval/calibrators.py @@ -1,7 +1,9 @@ -"""Calibrators turn the raw metric values measured on each control into a final -per-metric score. Each declares the control roles it needs, a per-perturbation -combine, and a cross-perturbation aggregate. +"""Calibrators turn raw control metric values into a final per-metric score. + +Each declares the control roles it needs, a per-perturbation combine, and a +cross-perturbation aggregate. """ + from __future__ import annotations import numpy as np @@ -29,17 +31,23 @@ def _bds_per_pert(raws, p): CALIBRATORS = { "drf": Calibrator( - "drf", ("positive", "negative"), _drf_per_pert, + "drf", + ("positive", "negative"), + _drf_per_pert, lambda v: {"mean": float(np.nanmean(v)), "median": float(np.nanmedian(v))}, description="Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025)", ), "bds": Calibrator( - "bds", ("positive", "negative"), _bds_per_pert, + "bds", + ("positive", "negative"), + _bds_per_pert, lambda v: {"bds": float(np.nanmean(v))}, description="Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026)", ), "score": Calibrator( - "score", ("prediction",), lambda raws, p: raws["prediction"], + "score", + ("prediction",), + lambda raws, p: raws["prediction"], lambda v: {"mean": float(np.nanmean(v)), "median": float(np.nanmedian(v))}, description="raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode)", ), diff --git a/scperteval/cli.py b/src/scperteval/cli.py similarity index 61% rename from scperteval/cli.py rename to src/scperteval/cli.py index 048e665..86909e3 100644 --- a/scperteval/cli.py +++ b/src/scperteval/cli.py @@ -1,4 +1,5 @@ """scPertEval command-line interface.""" + from __future__ import annotations import argparse @@ -19,7 +20,7 @@ def _concrete(p: Protocol) -> Protocol: """A tunable protocol at its default value; a fixed protocol unchanged.""" - return p.resolve(p.param.default) if p.parameterised else p + return p.resolve(p.param.default) if p.parameterised else p # type: ignore[union-attr] def _resolve_token(token: str) -> list[Protocol]: @@ -27,12 +28,12 @@ def _resolve_token(token: str) -> list[Protocol]: return [_concrete(p) for p in TABLE] if token in GROUPS: return [_concrete(p) for p in TABLE if p.group == token] - if "=" in token: # a tunable protocol with a value, e.g. mse_top_k=30 + if "=" in token: # a tunable protocol with a value, e.g. mse_top_k=30 name, _, value = token.partition("=") p = PROTOCOLS.get(name) if p is None or not p.parameterised: raise SystemExit(f"unknown tunable protocol {name!r}; try `scperteval list protocols`") - return [p.resolve(p.param.cast(value))] + return [p.resolve(p.param.cast(value))] # type: ignore[union-attr] p = PROTOCOLS.get(token) if p is None: raise SystemExit(f"unknown protocol {token!r}; try `scperteval list protocols`") @@ -40,6 +41,7 @@ def _resolve_token(token: str) -> list[Protocol]: def resolve_protocols(specs: list[str]) -> list[Protocol]: + """Resolve CLI protocol specs to a de-duplicated list of concrete protocols.""" out: list[Protocol] = [] for spec in specs: for token in spec.split(","): @@ -53,9 +55,11 @@ def resolve_protocols(specs: list[str]) -> list[Protocol]: def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: - """Run every protocol over the dataset, print the summary, and write the per-perturbation - CSV. Shared by ``calibrate`` and ``score`` (prediction vs ground truth); they differ only - in how ``ctx`` is built and which calibrator ``cfg.output`` selects.""" + """Run every protocol over the dataset, print the summary, and write the CSV. + + Shared by ``calibrate`` and ``score`` (prediction vs ground truth); they differ only in + how ``ctx`` is built and which calibrator ``cfg.output`` selects. + """ calibrator = CALIBRATORS[cfg.output] ctx.warm(protocols) aggregates, rows, timed = {}, [], [] @@ -74,13 +78,22 @@ def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: def cmd_calibrate(args) -> None: + """Run the ``calibrate`` command: score protocols against built-in controls (DRF/BDS).""" protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( - dataset=args.dataset, protocols=[p.name for p in protocols], de_method=args.de_method, - subsample=args.subsample, seed=args.seed, positive=args.positive, - negative=args.negative, output=args.output, out_dir=args.out_dir, - workers=args.workers, perturbation_key=args.perturbation_key, - control_label=args.control_label, min_cells=args.min_cells, + dataset=args.dataset, + protocols=[p.name for p in protocols], + de_method=args.de_method, + subsample=args.subsample, + seed=args.seed, + positive=args.positive, + negative=args.negative, + output=args.output, + out_dir=args.out_dir, + workers=args.workers, + perturbation_key=args.perturbation_key, + control_label=args.control_label, + min_cells=args.min_cells, profile=args.profile, ) ctx = Context(Dataset.load(cfg.dataset, cfg), cfg) @@ -88,14 +101,25 @@ def cmd_calibrate(args) -> None: def cmd_score(args) -> None: + """Run the ``score`` command: score predictions against ground truth, per protocol.""" protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( - dataset=args.dataset, protocols=[p.name for p in protocols], de_method=args.de_method, - subsample=args.subsample, seed=args.seed, output="score", out_dir=args.out_dir, - workers=args.workers, perturbation_key=args.perturbation_key, - control_label=args.control_label, min_cells=args.min_cells, profile=args.profile, - predictions=args.predictions, truth="gt_all_cells", + dataset=args.dataset, + protocols=[p.name for p in protocols], + de_method=args.de_method, + subsample=args.subsample, + seed=args.seed, + output="score", + out_dir=args.out_dir, + workers=args.workers, + perturbation_key=args.perturbation_key, + control_label=args.control_label, + min_cells=args.min_cells, + profile=args.profile, + predictions=args.predictions, + truth="gt_all_cells", ) + assert cfg.predictions is not None # required positional on the score subcommand ds = Dataset.load(cfg.dataset, cfg) ctx = Context(ds, cfg) ctx.predictions = PredictionSet.load(cfg.predictions, ds, cfg) @@ -103,12 +127,19 @@ def cmd_score(args) -> None: def cmd_de(args) -> None: + """Run the ``de`` command: export per-gene differential expression to HDF5.""" methods = [m.strip() for m in args.methods.split(",") if m.strip()] cfg = RunConfig( - dataset=args.dataset, protocols=[], de_method=methods[0], - subsample=args.subsample, seed=args.seed, out_dir=args.out_dir, - workers=args.workers, min_cells=args.min_cells, - perturbation_key=args.perturbation_key, control_label=args.control_label, + dataset=args.dataset, + protocols=[], + de_method=methods[0], + subsample=args.subsample, + seed=args.seed, + out_dir=args.out_dir, + workers=args.workers, + min_cells=args.min_cells, + perturbation_key=args.perturbation_key, + control_label=args.control_label, ) ctx = Context(Dataset.load(cfg.dataset, cfg), cfg) ctx._ensure_ref_sums() @@ -119,14 +150,18 @@ def cmd_de(args) -> None: def cmd_list(args) -> None: + """Run the ``list`` command: print the available building blocks of one category.""" + def reg(registry, fmt): return [fmt(n, registry.meta(n)) for n in registry.names()] if args.what == "protocols": + def descr(p): scope = "" if p.scope == "perturbation" else f", {p.scope}-wide" knob = f"{p.param.name}=…" if p.parameterised else f"space={p.space}" return f"{p.group}, {p.representation}{scope}, {knob}" + lines = [f"{p.name:24s} ({descr(p)})" for p in TABLE] elif args.what == "de-methods": lines = reg(DE_METHODS, lambda n, m: f"{n:10s} — {m.get('description', '')}") @@ -136,67 +171,88 @@ def descr(p): lines = reg(SOURCES, lambda n, m: f"{n:14s} ({m.get('provides')}) — {m.get('description', '')}") elif args.what == "calibrators": lines = [f"{n:6s} — {c.description}" for n, c in CALIBRATORS.items()] + else: + raise AssertionError(f"unexpected list target: {args.what!r}") print("\n".join(lines)) def main(argv=None) -> None: + """Parse arguments and dispatch to the selected subcommand.""" parser = argparse.ArgumentParser(prog="scperteval", description=__doc__) sub = parser.add_subparsers(dest="cmd", required=True) - calibrate = sub.add_parser( - "calibrate", help="calibrate protocols against positive/negative controls (DRF/BDS)") + calibrate = sub.add_parser("calibrate", help="calibrate protocols against positive/negative controls (DRF/BDS)") calibrate.add_argument("dataset", help="preprocessed .h5ad") - calibrate.add_argument("-p", "--protocols", action="append", default=[], - help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'") - calibrate.add_argument("--de-method", choices=DE_METHODS.names(), default="t-test", - help="DE backend for EVERY DE-dependent unit in the run: the interpolated " - "positive control, the top_k/degs spaces, the de_* protocols, and the WMSE weights") + calibrate.add_argument( + "-p", + "--protocols", + action="append", + default=[], + help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'", + ) + calibrate.add_argument( + "--de-method", + choices=DE_METHODS.names(), + default="t-test", + help="DE backend for EVERY DE-dependent unit in the run: the interpolated " + "positive control, the top_k/degs spaces, the de_* protocols, and the WMSE weights", + ) calibrate.add_argument("--subsample", type=int, default=8192) calibrate.add_argument("--seed", type=int, default=42) calibrate.add_argument("--positive", default="auto") calibrate.add_argument("--negative", default="auto") calibrate.add_argument( - "--output", default="drf", + "--output", + default="drf", choices=[n for n, c in CALIBRATORS.items() if "prediction" not in c.requires], - help="how per-perturbation values are calibrated (drf/bds)") + help="how per-perturbation values are calibrated (drf/bds)", + ) calibrate.add_argument("--out-dir", default="results") calibrate.add_argument("--workers", type=int, default=0, help="threads (0 = auto)") calibrate.add_argument("--perturbation-key", default="perturbation") calibrate.add_argument("--control-label", default="control") - calibrate.add_argument("--min-cells", type=int, default=30, - help="skip perturbations with fewer cells") - calibrate.add_argument("--profile", action="store_true", - help="also write a per-protocol wall-clock timing table") + calibrate.add_argument("--min-cells", type=int, default=30, help="skip perturbations with fewer cells") + calibrate.add_argument("--profile", action="store_true", help="also write a per-protocol wall-clock timing table") calibrate.add_argument("--quiet", action="store_true") calibrate.set_defaults(func=cmd_calibrate) - score = sub.add_parser( - "score", help="score model predictions against ground truth (real cells), per protocol") + score = sub.add_parser("score", help="score model predictions against ground truth (real cells), per protocol") score.add_argument("dataset", help="preprocessed .h5ad — the ground truth (real cells)") score.add_argument("predictions", help="predicted .h5ad — same genes and perturbation labels") - score.add_argument("-p", "--protocols", action="append", default=[], - help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'") - score.add_argument("--de-method", choices=DE_METHODS.names(), default="t-test", - help="DE backend for every DE-dependent unit (the top_k/degs spaces, the " - "de_* protocols, and the WMSE weights)") - score.add_argument("--subsample", type=int, default=8192, - help="cells in the all-perturbed reference sample (the ground truth itself is never subsampled)") + score.add_argument( + "-p", + "--protocols", + action="append", + default=[], + help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'", + ) + score.add_argument( + "--de-method", + choices=DE_METHODS.names(), + default="t-test", + help="DE backend for every DE-dependent unit (the top_k/degs spaces, the de_* protocols, and the WMSE weights)", + ) + score.add_argument( + "--subsample", + type=int, + default=8192, + help="cells in the all-perturbed reference sample (the ground truth itself is never subsampled)", + ) score.add_argument("--seed", type=int, default=42) score.add_argument("--out-dir", default="results") score.add_argument("--workers", type=int, default=0, help="threads (0 = auto)") score.add_argument("--perturbation-key", default="perturbation") score.add_argument("--control-label", default="control") - score.add_argument("--min-cells", type=int, default=30, - help="skip perturbations with fewer cells") - score.add_argument("--profile", action="store_true", - help="also write a per-protocol wall-clock timing table") + score.add_argument("--min-cells", type=int, default=30, help="skip perturbations with fewer cells") + score.add_argument("--profile", action="store_true", help="also write a per-protocol wall-clock timing table") score.add_argument("--quiet", action="store_true") score.set_defaults(func=cmd_score) de = sub.add_parser("de", help="write per-gene DE (statistic + adj p) per method to HDF5") de.add_argument("dataset", help="preprocessed .h5ad") - de.add_argument("--methods", default="t-test,MWU", - help="comma-separated DE methods to compute (GT first-half vs all-perturbed)") + de.add_argument( + "--methods", default="t-test,MWU", help="comma-separated DE methods to compute (GT first-half vs all-perturbed)" + ) de.add_argument("--subsample", type=int, default=8192) de.add_argument("--seed", type=int, default=42) de.add_argument("--out-dir", default="results") diff --git a/scperteval/context.py b/src/scperteval/context.py similarity index 70% rename from scperteval/context.py rename to src/scperteval/context.py index 976f332..eb91766 100644 --- a/scperteval/context.py +++ b/src/scperteval/context.py @@ -1,11 +1,19 @@ -"""The per-run engine: lazily builds and caches the shared building blocks, and -turns a (perturbation, source) into the exact view a protocol consumes.""" +"""The per-run engine that turns a (perturbation, source) into a protocol's view. + +Lazily builds and caches the shared building blocks, and turns a (perturbation, source) +into the exact view a protocol consumes. +""" + from __future__ import annotations import threading +from typing import TYPE_CHECKING, Any import numpy as np +if TYPE_CHECKING: + from sklearn.decomposition import PCA + from .blocks.de import DE_METHODS, moments, ttest_from_moments from .blocks.spaces import SPACES from .dataset import Dataset, to_dense @@ -13,18 +21,40 @@ from .sources import SOURCES from .types import Protocol, RunConfig +if TYPE_CHECKING: + from .predictions import PredictionSet + class Context: - """Owns the dataset, caches DE / PCA / control mean, and dispatches views. + """Per-run engine: owns the dataset, caches shared computations, and dispatches views. - Caches are keyed per perturbation, so the runner can fan perturbations out - across threads; ``current_pert`` is thread-local for the same reason. + Instantiated once per ``scperteval run`` call and passed to every metric as ``ctx``. + Caches are keyed per perturbation so the runner can fan work out across threads; + ``current_pert`` is thread-local for the same reason. + + Parameters + ---------- + dataset : ~scperteval.dataset.Dataset + Loaded and indexed AnnData wrapper. + cfg : ~scperteval.types.RunConfig + Resolved run options (DE method, subsample size, seed, …). + + Attributes + ---------- + ds : ~scperteval.dataset.Dataset + The underlying dataset. + cfg : ~scperteval.types.RunConfig + The resolved run configuration. + perturbations : list of str + Names of all perturbations that passed the ``min_cells`` filter. + current_pert : str or None + Thread-local name of the perturbation currently being processed. """ def __init__(self, dataset: Dataset, cfg: RunConfig): self.ds = dataset self.cfg = cfg - self.predictions = None # a PredictionSet in prediction-scoring mode, else None + self.predictions: PredictionSet | None = None # set in prediction-scoring mode self._local = threading.local() # Reentrant: several lazy initialisers (e.g. _ensure_ref_sums, ref_projection) # call reference() while already holding this lock, which a plain Lock would @@ -33,19 +63,21 @@ def __init__(self, dataset: Dataset, cfg: RunConfig): self._de: dict = {} self._mom: dict = {} self._weights: dict = {} - self._pca = None + self._pca: Any = None self._pca_k = 0 - self._control_mean = None - self._reference = None + self._control_mean: np.ndarray | None = None + self._reference: Reference | None = None self._ref_proj: dict = {} - self._ref_sums = None + self._ref_sums: tuple | None = None @property def perturbations(self): + """The list of perturbations evaluated in this run.""" return self.ds.perturbations @property def current_pert(self): + """The perturbation the current worker thread is processing.""" return getattr(self._local, "pert", None) @current_pert.setter @@ -53,8 +85,10 @@ def current_pert(self, value): self._local.pert = value def warm(self, protocols): - """Precompute shared singletons before the parallel loop so per-perturbation - threads only ever write per-perturbation cache keys.""" + """Precompute shared singletons before the parallel loop. + + So per-perturbation threads only ever write per-perturbation cache keys. + """ self.control_mean() if any(p.representation in ("population", "de") for p in protocols): self.reference() @@ -63,11 +97,13 @@ def warm(self, protocols): self._moments("control", None) if any(p.space == "pca50" for p in protocols): self.pca() - for space in {p.space for p in protocols - if p.representation == "population" and SPACES.meta(p.space).get("global_space")}: + for space in { + p.space for p in protocols if p.representation == "population" and SPACES.meta(p.space).get("global_space") + }: self.ref_projection(space) def view(self, pert: str, source: str, p: Protocol): + """Return ``source``'s datapoint for ``pert`` in the shape ``p`` consumes.""" if p.representation == "population": if source == "all_perturbed": return self._reference_population(p.space, pert) @@ -80,6 +116,7 @@ def view(self, pert: str, source: str, p: Protocol): raise ValueError(f"unknown protocol representation {p.representation!r}") def centroid(self, pert, source, centering): + """Pseudobulk centroid of ``source`` for ``pert``, optionally centered.""" arr = SOURCES[source](self, pert) if SOURCES.meta(source).get("provides") == "centroid": v = np.asarray(arr, dtype=np.float64).ravel() @@ -92,26 +129,29 @@ def centroid(self, pert, source, centering): return v def _de_view(self, pert, source, p): - """GT -> truth labels (its DEResult); a candidate -> its |score| ranking. - The negative candidate is tested against ``neg_reference`` (e.g. control) - rather than ``reference`` (the all-perturbed sample), the hybrid DE setup.""" + """Return the DE view: the truth's DEResult, or a candidate's ``|score|`` ranking. + + The negative candidate is tested against ``neg_reference`` (e.g. control) rather than + ``reference`` (the all-perturbed sample), the hybrid DE setup. + """ if source == self.cfg.truth: return self.de(pert, self.cfg.truth, p.reference) reference = p.neg_reference if (source == p.negative and p.neg_reference) else p.reference return np.abs(self.de(pert, source, reference).score) def de(self, pert, source, reference="all_perturbed"): - """DE for one (source vs reference) comparison; the reference moments are - leave-one-out, so a perturbation is never compared against a sample of itself.""" + """Differential expression for one (source vs reference) comparison, cached. + + The reference moments are leave-one-out, so a perturbation is never compared against + a sample of itself. + """ method = self.cfg.de_method key = (self._mom_key(source, pert), self._mom_key(reference, pert), method) if key not in self._de: if method == "t-test": - self._de[key] = ttest_from_moments(*self._moments(source, pert), - *self._moments(reference, pert)) + self._de[key] = ttest_from_moments(*self._moments(source, pert), *self._moments(reference, pert)) else: - self._de[key] = DE_METHODS[method](self._de_cells(source, pert), - self._de_cells(reference, pert)) + self._de[key] = DE_METHODS[method](self._de_cells(source, pert), self._de_cells(reference, pert)) return self._de[key] def _moments(self, source, pert): @@ -134,7 +174,7 @@ def _mom_key(source, pert): return source if source == "control" else (source, pert) def wmse_weights(self, pert): - """Mejia DEG weights: min-max normalised |effect size| of GT vs the reference.""" + """Mejia DEG weights: min-max normalised absolute effect size of GT vs the reference.""" if pert not in self._weights: s = np.abs(self.de(pert, self.cfg.truth, "all_perturbed").score) finite = np.isfinite(s) @@ -146,8 +186,10 @@ def wmse_weights(self, pert): # -- the all-perturbed reference: one sample, served leave-one-out ------------- def reference(self) -> Reference: - """The all-perturbed sample (subsampled + densified once), with each cell's - perturbation recorded so it can be served leave-one-out.""" + """The all-perturbed sample, subsampled and densified once. + + Each cell's perturbation is recorded so the sample can be served leave-one-out. + """ if self._reference is None: with self._init_lock: if self._reference is None: @@ -157,8 +199,10 @@ def reference(self) -> Reference: return self._reference def _reference_population(self, space, pert): - """The reference in a feature space with the target perturbation removed: - project the whole sample (cached for global spaces) then drop its rows.""" + """The reference in a feature space with the target perturbation removed. + + Project the whole sample (cached for global spaces) then drop its rows. + """ ref = self.reference() if SPACES.meta(space).get("global_space"): proj = self.ref_projection(space) @@ -176,8 +220,11 @@ def ref_projection(self, space): return self._ref_proj[space] def _ensure_ref_sums(self): - """Cache the reference's column sums and sums-of-squares once, so leave-one-out - moments are an O(target cells) subtraction rather than a re-densify per perturbation.""" + """Cache the reference's column sums and sums-of-squares once. + + Leave-one-out moments are then an O(target cells) subtraction rather than a + re-densify per perturbation. + """ if self._ref_sums is None: with self._init_lock: if self._ref_sums is None: @@ -200,6 +247,7 @@ def _reference_moments(self, pert): return mean, var, k def control_mean(self): + """The control centroid (cached).""" if self._control_mean is None: with self._init_lock: if self._control_mean is None: @@ -218,8 +266,11 @@ def pca(self, k=50): PCA_FIT_CAP = 50000 def _fit_pca(self, n_components): - """Fit PCA on (nearly) all cells; the subsample cap is for the O(n^2) - distance populations, not the PCA basis, which needs many cells to be stable.""" + """Fit PCA on (nearly) all cells. + + The subsample cap is for the O(n^2) distance populations, not the PCA basis, which + needs many cells to be stable. + """ from sklearn.decomposition import PCA n = self.ds.adata.n_obs diff --git a/scperteval/dataset.py b/src/scperteval/dataset.py similarity index 79% rename from scperteval/dataset.py rename to src/scperteval/dataset.py index fa8e307..da9da72 100644 --- a/scperteval/dataset.py +++ b/src/scperteval/dataset.py @@ -1,4 +1,5 @@ """Thin wrapper over a preprocessed AnnData with a perturbation column.""" + from __future__ import annotations import zlib @@ -11,7 +12,7 @@ def _seed(seed: int, *tags) -> np.random.Generator: - key = (seed,) + tuple(zlib.crc32(str(t).encode()) for t in tags) + key = (seed, *(zlib.crc32(str(t).encode()) for t in tags)) return np.random.default_rng(np.array(key, dtype=np.uint32)) @@ -27,7 +28,8 @@ def __init__(self, adata, cfg: RunConfig): self._index() @classmethod - def load(cls, path: str, cfg: RunConfig) -> "Dataset": + def load(cls, path: str, cfg: RunConfig) -> Dataset: + """Load a dataset from a preprocessed ``.h5ad`` path.""" return cls(ad.read_h5ad(path), cfg) def _index(self): @@ -50,6 +52,7 @@ def _index(self): self._mean_sum = self._mean_matrix.sum(0) def cells(self, pert: str, half: str | None = None) -> np.ndarray: + """Cells for ``pert``: the first/second split half, or all cells when ``half`` is None.""" if half == "first": idx = self.halves[pert][0] elif half == "second": @@ -59,23 +62,31 @@ def cells(self, pert: str, half: str | None = None) -> np.ndarray: return self.adata.X[idx] def control_cells(self, cap: int) -> np.ndarray: + """A capped subsample of the non-targeting control cells.""" return self.adata.X[self._cap(self.control_idx, cap, "control")] def all_perturbed_indices(self, cap: int) -> np.ndarray: - """One all-perturbed subsample (the reference sample, shared across perturbations). - The "pool" tag is a fixed reproducibility salt for the draw, not a public name.""" + """Indices of one all-perturbed subsample (the shared reference sample). + + The "pool" tag is a fixed reproducibility salt for the draw, not a public name. + """ return self._cap(np.where(self.pert != self.cfg.control_label)[0], cap, "pool") def allpert_mean_except(self, pert: str) -> np.ndarray: + """Mean of all per-perturbation means, excluding ``pert`` (leave-one-out).""" k = len(self.perturbations) return (self._mean_sum - self._mean_matrix[self._row[pert]]) / max(k - 1, 1) def allpert_mean(self) -> np.ndarray: - """Mean of all per-perturbation means (no target exclusion); a single vector - shared across perturbations, used as the cross-perturbation ranking baseline.""" + """Mean of all per-perturbation means (no target exclusion). + + A single vector shared across perturbations, used as the cross-perturbation ranking + baseline. + """ return self._mean_sum / max(len(self.perturbations), 1) def control_mean(self) -> np.ndarray: + """Pseudobulk centroid of the control cells.""" return np.asarray(self.adata.X[self.control_idx].mean(0)).ravel() def _cap(self, idx: np.ndarray, cap: int, *tags) -> np.ndarray: @@ -86,4 +97,5 @@ def _cap(self, idx: np.ndarray, cap: int, *tags) -> np.ndarray: def to_dense(X) -> np.ndarray: + """Return ``X`` as a dense array (densifying if sparse).""" return X.toarray() if sp.issparse(X) else np.asarray(X) diff --git a/scperteval/io.py b/src/scperteval/io.py similarity index 79% rename from scperteval/io.py rename to src/scperteval/io.py index b0c53cd..03b8fc5 100644 --- a/scperteval/io.py +++ b/src/scperteval/io.py @@ -1,4 +1,5 @@ """Human-readable summary plus a per-perturbation CSV named with dataset + time.""" + from __future__ import annotations from pathlib import Path @@ -8,6 +9,7 @@ def print_summary(cfg, aggregates: dict, calibrator, protocols) -> None: + """Print a formatted table of aggregate scores for every protocol.""" name = Path(cfg.dataset).stem print(f"\n{name} · {cfg.de_method} · subsample={cfg.subsample} · seed={cfg.seed} · output={cfg.output}\n") agg_keys = sorted({k for v in aggregates.values() for k in v}) @@ -22,11 +24,16 @@ def print_summary(cfg, aggregates: dict, calibrator, protocols) -> None: def write_rows(cfg, rows: list, timestamp: str) -> Path: + """Write per-perturbation rows (raw controls + calibrated score) to a timestamped CSV.""" out_dir = Path(cfg.out_dir) out_dir.mkdir(parents=True, exist_ok=True) df = pd.DataFrame(rows) - for col, val in (("dataset", Path(cfg.dataset).stem), ("de_method", cfg.de_method), - ("subsample", cfg.subsample), ("seed", cfg.seed)): + for col, val in ( + ("dataset", Path(cfg.dataset).stem), + ("de_method", cfg.de_method), + ("subsample", cfg.subsample), + ("seed", cfg.seed), + ): df[col] = val path = out_dir / f"{Path(cfg.dataset).stem}__{timestamp}__{cfg.output}.csv" df.to_csv(path, index=False) @@ -37,9 +44,16 @@ def write_timing(cfg, timed: list, timestamp: str) -> Path: """Write per-protocol wall-clock seconds (one row per protocol).""" out_dir = Path(cfg.out_dir) out_dir.mkdir(parents=True, exist_ok=True) - rows = [{"dataset": Path(cfg.dataset).stem, "protocol": p.name, - "representation": p.representation, "space": p.space, "seconds": seconds} - for p, seconds in timed] + rows = [ + { + "dataset": Path(cfg.dataset).stem, + "protocol": p.name, + "representation": p.representation, + "space": p.space, + "seconds": seconds, + } + for p, seconds in timed + ] path = out_dir / f"{Path(cfg.dataset).stem}__{timestamp}__timing.csv" pd.DataFrame(rows).to_csv(path, index=False) return path @@ -49,7 +63,8 @@ def write_de(cfg, genes, perturbations, results: dict, timestamp: str) -> Path: """Write per-gene DE (statistic + adjusted p) per method to an HDF5 file. Layout: ``genes``, ``perturbations``, and one group per method holding - ``statistic`` and ``pvalue_adj`` matrices (perturbations x genes).""" + ``statistic`` and ``pvalue_adj`` matrices (perturbations x genes). + """ import h5py out_dir = Path(cfg.out_dir) diff --git a/scperteval/predictions.py b/src/scperteval/predictions.py similarity index 93% rename from scperteval/predictions.py rename to src/scperteval/predictions.py index c239e04..a963214 100644 --- a/scperteval/predictions.py +++ b/src/scperteval/predictions.py @@ -5,6 +5,7 @@ any order) and the same perturbation column; columns are reordered to the dataset's gene order so every metric's positional ``gt - prediction`` comparison lines up. """ + from __future__ import annotations import anndata as ad @@ -18,7 +19,8 @@ def _align_genes(pred_genes: np.ndarray, ds_genes: np.ndarray) -> np.ndarray: """Indices that reorder the prediction's genes into the dataset's gene order. Errors (naming what's wrong) unless the two gene sets are identical -- metrics compare - gene vectors positionally, so a mismatch would silently compare the wrong genes.""" + gene vectors positionally, so a mismatch would silently compare the wrong genes. + """ pred_set, ds_set = set(map(str, pred_genes)), set(map(str, ds_genes)) missing = [g for g in map(str, ds_genes) if g not in pred_set] extra = [g for g in map(str, pred_genes) if g not in ds_set] @@ -47,7 +49,8 @@ def __init__(self, adata, ds: Dataset, cfg: RunConfig): self.pert = np.asarray(adata.obs[cfg.perturbation_key]).astype(str) @classmethod - def load(cls, path: str, ds: Dataset, cfg: RunConfig) -> "PredictionSet": + def load(cls, path: str, ds: Dataset, cfg: RunConfig) -> PredictionSet: + """Load a prediction ``.h5ad`` and gene-align it to ``ds``.""" return cls(ad.read_h5ad(path), ds, cfg) def cells(self, pert: str) -> np.ndarray: diff --git a/scperteval/protocols/__init__.py b/src/scperteval/protocols/__init__.py similarity index 57% rename from scperteval/protocols/__init__.py rename to src/scperteval/protocols/__init__.py index ed1c0e8..2f351c2 100644 --- a/scperteval/protocols/__init__.py +++ b/src/scperteval/protocols/__init__.py @@ -1,2 +1,3 @@ """Evaluation protocols: pure metrics plus the declarative protocol table.""" -from .table import GROUPS, PROTOCOLS, TABLE # noqa: F401 + +from .table import GROUPS, PROTOCOLS, TABLE diff --git a/src/scperteval/protocols/metrics.py b/src/scperteval/protocols/metrics.py new file mode 100644 index 0000000..2720ed1 --- /dev/null +++ b/src/scperteval/protocols/metrics.py @@ -0,0 +1,406 @@ +r"""Evaluation-protocol metrics — the exact implementation of every metric. + +A metric takes the ground-truth and a prediction (whichever control is being scored) plus +the context, and returns a score. The protocol's ``representation`` sets each datapoint's +shape — ``centroid`` -> a 1-D pseudobulk vector, ``population`` -> a (cells x genes) array, +``de`` -> a DEResult (GT) / \|score\| ranking (prediction). Its ``scope`` sets the call: a +``perturbation``-scope metric gets one perturbation's (gt, prediction) and returns a scalar; +a ``dataset``-scope metric gets the list of every perturbation's gt and prediction and +returns one score per perturbation (e.g. ``rank_retrieval``). + +Every metric is implemented in full here; only external numerical libraries (numpy, +scikit-learn, geomloss) are relied upon. So a metric is completely defined by its function +below plus its row in ``table.py`` — nothing is hidden behind another layer. +""" + +from __future__ import annotations + +import numpy as np +from sklearn.metrics import average_precision_score, roc_auc_score + +# --- shared parameter blocks, substituted into docstrings at decoration time --- + +_CENTROID = """\ +gt : numpy.ndarray + Ground-truth pseudobulk profile, shape ``(G,)``. +prediction : numpy.ndarray + Predicted pseudobulk profile, shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + +_CENTROID_W = """\ +gt : numpy.ndarray + Ground-truth pseudobulk profile, shape ``(G,)``. +prediction : numpy.ndarray + Predicted pseudobulk profile, shape ``(G,)``. +ctx : Context + Provides per-gene WMSE weights via ``ctx.wmse_weights``.""" + +_POPULATION = """\ +gt : numpy.ndarray + Ground-truth cell matrix, shape ``(n, G)``. +prediction : numpy.ndarray + Predicted cell matrix, shape ``(m, G)``. +ctx : Context + Unused; present for signature compatibility.""" + +_DATASET = """\ +gt : list of numpy.ndarray + Ground-truth centroids, one per perturbation, each shape ``(G,)``. +prediction : list of numpy.ndarray + Predicted centroids, one per perturbation, each shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + +_DE = """\ +gt : ~scperteval.types.DEResult + Ground-truth DE result; ``gt.pvalue_adj`` defines the positive class. +prediction : numpy.ndarray + Per-gene absolute DE score ranking from the candidate source, shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + + +def _doc(**subs): + """Decorator that substitutes %(key)s placeholders, propagating surrounding indentation. + + Python's ``%`` substitution only indents the first line of a multi-line value. + This decorator detects the column position of each placeholder and re-indents all + continuation lines to match, so the substituted text stays inside the RST section. + """ + + def deco(fn): + doc = fn.__doc__ + for key, value in subs.items(): + placeholder = f"%({key})s" + while placeholder in doc: + idx = doc.index(placeholder) + line_start = doc.rfind("\n", 0, idx) + 1 + indent = " " * (idx - line_start) + indented = ("\n" + indent).join(value.split("\n")) + doc = doc[:idx] + indented + doc[idx + len(placeholder) :] + fn.__doc__ = doc + return fn + + return deco + + +def _sq_dists(X, Y): + """Pairwise squared euclidean distances via ||x||^2 + ||y||^2 - 2 x.y. + + Routed through a BLAS matrix product, which releases the GIL so the + per-perturbation thread pool actually parallelises. + """ + xx = np.einsum("ij,ij->i", X, X) + yy = np.einsum("ij,ij->i", Y, Y) + sq = xx[:, None] + yy[None, :] - 2.0 * (X @ Y.T) + return np.maximum(sq, 0.0) + + +def _within_unbiased(sq, n): + """Unbiased (U-statistic) mean within-population euclidean distance.""" + if n <= 1: + return 0.0 + return float(np.sqrt(sq).sum() / (n * (n - 1))) + + +@_doc(params=_CENTROID) +def pearson(gt, prediction, ctx): + r"""Pearson correlation between pseudobulk profiles. + + .. math:: + + r = \frac{\sum_g (gt_g - \bar{gt})(pred_g - \bar{pred})}{ + \sqrt{\sum_g (gt_g - \bar{gt})^2 \cdot \sum_g (pred_g - \bar{pred})^2}} + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Pearson r in [-1, 1]; 1 is perfect. + """ + return float(np.corrcoef(gt, prediction)[0, 1]) + + +@_doc(params=_CENTROID) +def mse(gt, prediction, ctx): + r"""Mean squared error between pseudobulk profiles. + + .. math:: + + \text{MSE} = \frac{1}{G}\sum_{g=1}^G (gt_g - pred_g)^2 + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Non-negative MSE; 0 is perfect. + """ + return float(np.mean((gt - prediction) ** 2)) + + +@_doc(params=_CENTROID_W) +def weighted_mse(gt, prediction, ctx, exp=2.0): + r"""MSE weighted by ground-truth effect size raised to ``exp``. + + Weights are min-max normalised per-gene; high-effect genes contribute more. + + .. math:: + + \text{wMSE} = \sum_g w_g \,(gt_g - pred_g)^2, \quad + w_g \propto |s_g|^{\text{exp}} / \sum_{g'} |s_{g'}|^{\text{exp}} + + where :math:`s_g` is the ground-truth DE t-statistic for gene :math:`g`. + + Parameters + ---------- + %(params)s + exp : float + Exponent applied to the effect-size weights (default 2.0). + + Returns + ------- + float + Non-negative weighted MSE; 0 is perfect. + """ + w = ctx.wmse_weights(ctx.current_pert) ** exp + total = w.sum() + w = w / total if total > 0 else np.full(w.size, 1.0 / w.size) + return float(np.sum(w * (gt - prediction) ** 2)) + + +@_doc(params=_POPULATION) +def energy_distance(gt, prediction, ctx): + r"""Székely–Rizzo energy distance with bias-corrected within-population terms. + + .. math:: + + E(X, Y) = 2\,\mathbb{E}[\|X - Y\|] + - \mathbb{E}[\|X - X'\|] - \mathbb{E}[\|Y - Y'\|] + + Within-population terms use the unbiased (U-statistic) estimator. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Energy distance >= 0; 0 is perfect (identical distributions). + Returns ``nan`` if either population is empty. + """ + if len(gt) == 0 or len(prediction) == 0: + return float("nan") + X = gt.astype(np.float64) + Y = prediction.astype(np.float64) + cross = np.sqrt(_sq_dists(X, Y)).mean() + xx = _within_unbiased(_sq_dists(X, X), len(X)) + yy = _within_unbiased(_sq_dists(Y, Y), len(Y)) + return float(2.0 * cross - xx - yy) + + +@_doc(params=_POPULATION) +def unbiased_mmd_median(gt, prediction, ctx): + r"""Unbiased RBF-MMD² with median-heuristic bandwidth (Gretton 2012). + + .. math:: + + \widehat{\text{MMD}}^2(X, Y) + = \frac{1}{n(n-1)} \sum_{i \neq j} k(x_i, x_j) + + \frac{1}{m(m-1)} \sum_{i \neq j} k(y_i, y_j) + - \frac{2}{nm} \sum_{i,j} k(x_i, y_j) + + with :math:`k(x,y) = \exp(-\|x-y\|^2 / 2\sigma^2)` and :math:`\sigma` the median + pairwise Euclidean distance over the pooled sample. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + MMD² (may be slightly negative due to estimation variance); 0 is perfect. + Returns ``nan`` if either population has fewer than 2 cells. + """ + if len(gt) < 2 or len(prediction) < 2: + return float("nan") + X = gt.astype(np.float64) + Y = prediction.astype(np.float64) + nx, ny = len(X), len(Y) + pooled = np.vstack([X, Y]) + euc = np.sqrt(_sq_dists(pooled, pooled)) + n = euc.shape[0] + sigma = float(np.median(euc[~np.eye(n, dtype=bool)])) + if sigma <= 0: + return 0.0 + gamma = 1.0 / (2.0 * sigma * sigma) + k_xx = np.exp(-gamma * _sq_dists(X, X)) + k_yy = np.exp(-gamma * _sq_dists(Y, Y)) + k_xy = np.exp(-gamma * _sq_dists(X, Y)) + xx = (k_xx.sum() - np.trace(k_xx)) / (nx * (nx - 1)) + yy = (k_yy.sum() - np.trace(k_yy)) / (ny * (ny - 1)) + return float(xx + yy - 2.0 * k_xy.mean()) + + +_geomloss_cache: dict = {} + + +@_doc(params=_POPULATION) +def sinkhorn_w2(gt, prediction, ctx, blur=0.05): + r"""Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2). + + .. math:: + + W_2(X, Y) = \sqrt{2\,S_\varepsilon(X, Y)} + + where :math:`S_\varepsilon` is the debiased Sinkhorn divergence with blur + :math:`\varepsilon`. Requires ``geomloss`` and ``torch``. + + Parameters + ---------- + %(params)s + blur : float + Sinkhorn entropic regularisation parameter (default 0.05). + + Returns + ------- + float + W2 distance >= 0; 0 is perfect. + Returns ``nan`` if either population is empty. + """ + if len(gt) == 0 or len(prediction) == 0: + return float("nan") + import torch + from geomloss import SamplesLoss + + loss = _geomloss_cache.get(blur) + if loss is None: + torch.set_num_threads(1) + loss = SamplesLoss(loss="sinkhorn", p=2, blur=blur, debias=True, backend="tensorized") + _geomloss_cache[blur] = loss + Xt = torch.as_tensor(np.ascontiguousarray(gt), dtype=torch.float32) + Yt = torch.as_tensor(np.ascontiguousarray(prediction), dtype=torch.float32) + a = torch.full((len(gt),), 1.0 / len(gt), dtype=torch.float32) + b = torch.full((len(prediction),), 1.0 / len(prediction), dtype=torch.float32) + with torch.no_grad(): + val = float(loss(a, Xt, b, Yt)) + return float(np.sqrt(max(2.0 * val, 0.0))) + + +@_doc(params=_DATASET) +def rank_retrieval(gt, prediction, ctx, transpose=False): + r"""Cross-perturbation retrieval rank — dataset-scope metric, lower is better. + + Builds the ``(n x n)`` squared-distance matrix between all predicted and ground-truth + centroids, then reads off the diagonal rank (column-wise by default). + + .. math:: + + \text{rank}(a) = \frac{\text{rank}_{\text{col}}(D_{aa})}{n - 1}, \quad + D_{ij} = \|P_i - G_j\|^2 + + where :math:`P_i` and :math:`G_j` are the predicted and ground-truth centroids. + ``transpose_rank`` transposes the matrix first (each prediction ranked among all GTs). + Tie-breaking noise (seed 42) matches the DRF calibration convention. + + Parameters + ---------- + %(params)s + transpose : bool + If ``True``, rank row-wise (each prediction vs all GTs) instead of column-wise. + + Returns + ------- + np.ndarray + Per-perturbation normalised rank in [0, 1]; 0 is a perfect top-1 retrieval. + """ + G = np.vstack(gt) + P = np.vstack(prediction) + sq = _sq_dists(P, G) + if transpose: + sq = sq.T + n = sq.shape[0] + noise = np.random.default_rng(42).uniform(0, 1e-12, size=sq.shape) + ranks = np.argsort(np.argsort(sq + noise, axis=0), axis=0) + return np.diag(ranks).astype(np.float64) / max(n - 1, 1) + + +@_doc(params=_DE) +def de_auprc(gt, prediction, ctx): + """Area under the precision-recall curve for DEG recovery. + + Positive class: ground-truth DEGs with ``gt.pvalue_adj < 0.05``. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + AUPRC in [0, 1]; higher is better. + Returns ``nan`` if all genes fall in the same class. + """ + labels = gt.pvalue_adj < 0.05 + if labels.sum() == 0 or labels.sum() == labels.size: + return float("nan") + return float(average_precision_score(labels, prediction)) + + +@_doc(params=_DE) +def de_auroc(gt, prediction, ctx): + """Area under the ROC curve for DEG recovery. + + Positive class: ground-truth DEGs with ``gt.pvalue_adj < 0.05``. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + AUROC in [0, 1]; higher is better. + Returns ``nan`` if all genes fall in the same class. + """ + labels = gt.pvalue_adj < 0.05 + if labels.sum() == 0 or labels.sum() == labels.size: + return float("nan") + return float(roc_auc_score(labels, prediction)) + + +@_doc(params=_DE) +def de_overlap(gt, prediction, ctx, k=50): + r"""Top-k overlap between ground-truth and predicted DE gene rankings. + + .. math:: + + \text{Overlap}_k + = \frac{|\text{top-}k(|gt.score|) \cap \text{top-}k(pred)|}{k} + + Parameters + ---------- + %(params)s + k : int + Number of top genes to intersect (default 50). + + Returns + ------- + float + Fraction of top-k genes shared, in [0, 1]; higher is better. + Returns ``nan`` if k >= number of genes. + """ + truth = np.abs(gt.score) + if k >= truth.size: + return float("nan") + top_truth = np.argpartition(-truth, k - 1)[:k] + top_prediction = np.argpartition(-prediction, k - 1)[:k] + return float(np.intersect1d(top_truth, top_prediction).size) / k diff --git a/scperteval/protocols/table.py b/src/scperteval/protocols/table.py similarity index 71% rename from scperteval/protocols/table.py rename to src/scperteval/protocols/table.py index 946609b..0cd535e 100644 --- a/scperteval/protocols/table.py +++ b/src/scperteval/protocols/table.py @@ -5,30 +5,42 @@ with no value the family default is used. To add a protocol, write a metric in ``metrics.py`` and add one row below. """ + from __future__ import annotations from functools import partial +from typing import Any from ..blocks.spaces import degs_space, pca_space, top_space from ..types import Param, Protocol from . import metrics as M - # --- parameter families: a CLI value selects a feature space (or feeds the metric) --- -top_k = Param("k", int, 50, space=top_space) # top-k DEGs by effect size -pca_k = Param("k", int, 50, space=pca_space) # k principal components +top_k = Param("k", int, 50, space=top_space) # top-k DEGs by effect size +pca_k = Param("k", int, 50, space=pca_space) # k principal components degs_padj = Param("padj", float, 0.05, space=degs_space) # DEGs at adjusted p < padj -overlap_k = Param("k", int, 50) # passed straight to de_overlap's k +overlap_k = Param("k", int, 50) # passed straight to de_overlap's k # --- shared wiring bundles (controls + score scale), splatted into rows with ** --- -_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") -_PB_CTRL = dict(group="pseudobulk", positive="interpolated", negative="control") -_LOWER = dict(better="lower", perfect=0.0) -_DIST = dict(group="distributional", positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) -_DE = dict(group="de", positive="tech_dup", negative="all_perturbed", reference="all_perturbed", - neg_reference="control", better="higher", perfect=1.0) -_RANK = dict(group="pseudobulk", positive="interpolated", negative="global_mean", better="lower", perfect=0.0) +_PB: dict[str, Any] = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") +_PB_CTRL: dict[str, Any] = dict(group="pseudobulk", positive="interpolated", negative="control") +_LOWER: dict[str, Any] = dict(better="lower", perfect=0.0) +_DIST: dict[str, Any] = dict( + group="distributional", positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0 +) +_DE: dict[str, Any] = dict( + group="de", + positive="tech_dup", + negative="all_perturbed", + reference="all_perturbed", + neg_reference="control", + better="higher", + perfect=1.0, +) +_RANK: dict[str, Any] = dict( + group="pseudobulk", positive="interpolated", negative="global_mean", better="lower", perfect=0.0 +) TABLE = [ @@ -43,14 +55,14 @@ Protocol("mse_top_k", M.mse, representation="centroid", param=top_k, **_PB, **_LOWER), Protocol("mse_degs_padj", M.mse, representation="centroid", param=degs_padj, **_PB, **_LOWER), Protocol("pearson_pert_top_k", M.pearson, representation="centroid", centering="allpert", param=top_k, **_PB_CTRL), - Protocol("pearson_pert_degs_padj", M.pearson, representation="centroid", centering="allpert", param=degs_padj, **_PB_CTRL), - + Protocol( + "pearson_pert_degs_padj", M.pearson, representation="centroid", centering="allpert", param=degs_padj, **_PB_CTRL + ), # --- cross-perturbation retrieval rank (dataset-wide over centroids) --- - Protocol("rank", partial(M.rank_retrieval, transpose=False), - representation="centroid", scope="dataset", **_RANK), - Protocol("transpose_rank", partial(M.rank_retrieval, transpose=True), - representation="centroid", scope="dataset", **_RANK), - + Protocol("rank", partial(M.rank_retrieval, transpose=False), representation="centroid", scope="dataset", **_RANK), + Protocol( + "transpose_rank", partial(M.rank_retrieval, transpose=True), representation="centroid", scope="dataset", **_RANK + ), # --- distributional: distances between cell populations (positive = technical duplicate) --- Protocol("unbiased_mmd_median_top_k", M.unbiased_mmd_median, representation="population", param=top_k, **_DIST), Protocol("unbiased_mmd_median_pca_k", M.unbiased_mmd_median, representation="population", param=pca_k, **_DIST), @@ -58,7 +70,6 @@ Protocol("energy_distance_pca_k", M.energy_distance, representation="population", param=pca_k, **_DIST), Protocol("sinkhorn_w2_top_k", M.sinkhorn_w2, representation="population", param=top_k, **_DIST), Protocol("sinkhorn_w2_pca_k", M.sinkhorn_w2, representation="population", param=pca_k, **_DIST), - # --- differential expression: GT DEGs vs prediction ranking --- Protocol("de_auprc", M.de_auprc, representation="de", **_DE), Protocol("de_auroc", M.de_auroc, representation="de", **_DE), diff --git a/scperteval/reference.py b/src/scperteval/reference.py similarity index 82% rename from scperteval/reference.py rename to src/scperteval/reference.py index 040c893..82b0fa3 100644 --- a/scperteval/reference.py +++ b/src/scperteval/reference.py @@ -1,12 +1,14 @@ """The comparison reference: a fixed cell sample served leave-one-out.""" + from __future__ import annotations import warnings +import numpy as np + class Reference: - """A comparison sample of cells (the all-perturbed subsample, or non-targeting - control), served leave-one-out. + """A fixed cell sample (all-perturbed subsample or control), served leave-one-out. ``subset(P)`` returns the sample with perturbation ``P``'s own cells removed, so a perturbation is never scored against a reference that contains itself. When @@ -16,13 +18,13 @@ class Reference: """ def __init__(self, cells, labels=None, warn_frac: float = 0.10): - self.cells = cells # densified once, (n_cells, n_genes) - self.labels = labels # per-cell perturbation, or None + self.cells = cells # densified once, (n_cells, n_genes) + self.labels = labels # per-cell perturbation, or None self.warn_frac = warn_frac self._n = len(cells) self._warned: set = set() - def keep(self, exclude) -> "object": + def keep(self, exclude) -> np.ndarray | None: """Boolean mask of the cells to keep, or None when nothing is excluded.""" if self.labels is None: return None @@ -31,6 +33,7 @@ def keep(self, exclude) -> "object": return mask def subset(self, exclude): + """Return the sample with ``exclude``'s own cells removed (leave-one-out).""" mask = self.keep(exclude) return self.cells if mask is None else self.cells[mask] diff --git a/scperteval/registry.py b/src/scperteval/registry.py similarity index 53% rename from scperteval/registry.py rename to src/scperteval/registry.py index 25418c0..9a81a11 100644 --- a/scperteval/registry.py +++ b/src/scperteval/registry.py @@ -1,35 +1,63 @@ """A minimal decorator registry for the pluggable building blocks.""" + from __future__ import annotations -from typing import Callable +from collections.abc import Callable class Registry: - """Maps a name to a function plus optional metadata, populated by decoration.""" + """Maps a name to a function plus optional metadata, populated by decoration. + + Parameters + ---------- + kind : str + Human-readable label used in error messages (e.g. ``"de-method"``). + + Example + ------- + >>> from scperteval.registry import Registry + >>> MY_REG = Registry("example") + >>> @MY_REG.register("double", description="multiply by 2") + ... def double(x): + ... return x * 2 + >>> MY_REG["double"](3) + 6 + >>> MY_REG.meta("double") + {'description': 'multiply by 2'} + >>> MY_REG.names() + ['double'] + """ def __init__(self, kind: str): self.kind = kind self._items: dict[str, tuple[Callable, dict]] = {} def register(self, name: str, **meta) -> Callable: + """Decorator that registers a function under ``name`` with optional metadata.""" + def deco(fn: Callable) -> Callable: self._items[name] = (fn, meta) return fn + return deco def add(self, name: str, fn: Callable, **meta) -> None: + """Register a function under ``name`` without using the decorator form.""" self._items[name] = (fn, meta) def __getitem__(self, name: str) -> Callable: + """Return the function registered under ``name``.""" if name not in self._items: raise KeyError(f"unknown {self.kind} {name!r}; available: {self.names()}") return self._items[name][0] def meta(self, name: str) -> dict: + """Return the metadata dict registered alongside ``name``.""" return self._items[name][1] def __contains__(self, name: str) -> bool: return name in self._items def names(self) -> list[str]: + """Sorted list of registered names.""" return sorted(self._items) diff --git a/scperteval/runner.py b/src/scperteval/runner.py similarity index 61% rename from scperteval/runner.py rename to src/scperteval/runner.py index b17de77..2fa00a1 100644 --- a/scperteval/runner.py +++ b/src/scperteval/runner.py @@ -1,4 +1,5 @@ """Runs one protocol over every perturbation and applies the chosen calibrator.""" + from __future__ import annotations import os @@ -12,13 +13,16 @@ def n_workers(cfg) -> int: + """Resolve the worker-thread count (``0`` = auto: CPU count minus 2, capped at 16).""" return cfg.workers if cfg.workers > 0 else max(1, min(16, (os.cpu_count() or 2) - 2)) def resolve_roles(p: Protocol, cfg) -> dict: - """Map each candidate role a calibrator may require to a source name. ``positive`` / - ``negative`` come from the protocol (or a CLI override); ``prediction`` is always the - model-prediction source (used by the ``score`` calibrator).""" + """Map each candidate calibrator role to a source name. + + ``positive`` / ``negative`` come from the protocol (or a CLI override); ``prediction`` + is always the model-prediction source (used by the ``score`` calibrator). + """ return { "positive": cfg.positive if cfg.positive != "auto" else p.positive, "negative": cfg.negative if cfg.negative != "auto" else p.negative, @@ -27,10 +31,29 @@ def resolve_roles(p: Protocol, cfg) -> dict: def run_protocol(p: Protocol, ctx, calibrator: Calibrator): - """Return (aggregate scores, per-perturbation rows, wall-clock seconds) for one protocol. - - ``scope`` chooses the loop: per-perturbation protocols score one perturbation at a time; - dataset-scope protocols hand the metric every perturbation's datapoint at once. + """Run one protocol over every perturbation and apply the calibrator. + + ``p.scope`` chooses the execution path: ``"perturbation"``-scope protocols run in a + thread pool (one perturbation at a time); ``"dataset"``-scope protocols collect all + perturbations' datapoints first, then call the metric once. + + Parameters + ---------- + p : ~scperteval.types.Protocol + The concrete (non-parameterised) protocol to evaluate. + ctx : ~scperteval.context.Context + Per-run context holding the dataset, caches, and building blocks. + calibrator : ~scperteval.types.Calibrator + Calibrator that converts raw positive/negative control scores into a final value. + + Returns + ------- + aggregates : dict + Aggregate scores across perturbations (e.g. ``{"mean": 0.42, "median": 0.38}``). + rows : list of dict + Per-perturbation records with raw control values and the calibrated score. + seconds : float + Wall-clock time for this protocol. """ roles = resolve_roles(p, ctx.cfg) needed = {role: roles[role] for role in calibrator.requires} @@ -42,15 +65,21 @@ def run_protocol(p: Protocol, ctx, calibrator: Calibrator): def _finalize(p, calibrator, perts, raws_list): """Per-perturbation rows + the aggregate, from each perturbation's raw control values.""" per_pert = [calibrator.per_pert(raws, p) for raws in raws_list] - rows = [{"protocol": p.name, "perturbation": pert, - **{f"raw_{role}": raws[role] for role in raws}, - calibrator.name: value} - for pert, raws, value in zip(perts, raws_list, per_pert)] + rows = [ + { + "protocol": p.name, + "perturbation": pert, + **{f"raw_{role}": raws[role] for role in raws}, + calibrator.name: value, + } + for pert, raws, value in zip(perts, raws_list, per_pert) + ] return calibrator.aggregate(np.asarray(per_pert, dtype=float)), rows def _run_per_perturbation(p: Protocol, ctx, calibrator: Calibrator, needed: dict): """Score one perturbation at a time (across a thread pool), gt vs each control.""" + def work(pert): ctx.current_pert = pert gt = ctx.view(pert, ctx.cfg.truth, p) @@ -66,11 +95,11 @@ def work(pert): def _run_dataset(p: Protocol, ctx, calibrator: Calibrator, needed: dict): - """Dataset-scope protocols: build every perturbation's gt and control datapoints, hand - the metric the full lists at once, then read off each perturbation's score. + """Score dataset-scope protocols by handing the metric all perturbations at once. - Perturbations are treated as a single group (these datasets are single-covariate); - drf instead ranks within each covariate group. + Build every perturbation's gt and control datapoints, call the metric once on the full + lists, then read off each perturbation's score. Perturbations are treated as a single + group (these datasets are single-covariate); drf instead ranks within each covariate group. """ perts = ctx.perturbations @@ -92,8 +121,11 @@ def collect(source): def compute_de_export(ctx, methods): - """{method: (statistic, pvalue_adj)} matrices (perturbations x genes) for each - method's GT(first-half)-vs-all-perturbed differential expression.""" + """Per-gene DE matrices for each method, for export. + + Returns ``{method: (statistic, pvalue_adj)}`` matrices (perturbations x genes) for each + method's ground-truth-vs-all-perturbed differential expression. + """ out = {} for method in methods: ctx.cfg.de_method = method diff --git a/src/scperteval/sources.py b/src/scperteval/sources.py new file mode 100644 index 0000000..c482173 --- /dev/null +++ b/src/scperteval/sources.py @@ -0,0 +1,110 @@ +"""Control/reference sources: each yields a perturbation's cells or pseudobulk centroid. + +A source's positive/negative role is chosen at the CLI; the registry just produces +the data. ``provides`` ("cells" or "centroid") drives the runner's compatibility +check and how the context turns a source into a view. ``description`` is shown by +``scperteval list sources``. +""" + +from __future__ import annotations + +import numpy as np + +from .dataset import to_dense +from .registry import Registry + +SOURCES = Registry("source") + + +@SOURCES.register( + "gt_half", + provides="cells", + description="ground truth — the first half of a perturbation's cells (calibration truth)", +) +def src_gt_half(ctx, pert): + """Ground-truth cells: the first half of the perturbation's cells.""" + return ctx.ds.cells(pert, half="first") + + +@SOURCES.register( + "gt_all_cells", + provides="cells", + description="ground truth — all of a perturbation's real cells (prediction-scoring truth)", +) +def src_gt_all_cells(ctx, pert): + """Ground-truth cells: all of the perturbation's real cells.""" + return ctx.ds.cells(pert) + + +@SOURCES.register( + "prediction", + provides="cells", + description="model-predicted cells for the perturbation, from the --predictions h5ad", +) +def src_prediction(ctx, pert): + """Model-predicted cells for the perturbation, gene-aligned to the dataset.""" + return ctx.predictions.cells(pert) + + +@SOURCES.register( + "tech_dup", + provides="cells", + description="technical duplicate — the held-out second half (single-cell positive control)", +) +def src_tech_dup(ctx, pert): + """Technical-duplicate cells: the perturbation's held-out second half.""" + return ctx.ds.cells(pert, half="second") + + +@SOURCES.register("control", provides="cells", description="non-targeting control cells") +def src_control(ctx, pert): + """Non-targeting control cells (subsampled).""" + return ctx.ds.control_cells(ctx.cfg.subsample) + + +@SOURCES.register( + "all_perturbed", + provides="cells", + description="all-perturbed reference sample, leave-one-out (single-cell negative control)", +) +def src_all_perturbed(ctx, pert): + """All-perturbed reference cells, with the target perturbation removed.""" + return ctx.reference().subset(pert) + + +@SOURCES.register( + "all_perturbed_mean", + provides="centroid", + description="all-perturbed mean, excluding the target — leave-one-out " + "(pseudobulk sibling of all_perturbed; pseudobulk negative control)", +) +def src_all_perturbed_mean(ctx, pert): + """All-perturbed pseudobulk mean, excluding the target perturbation.""" + return ctx.ds.allpert_mean_except(pert) + + +@SOURCES.register( + "global_mean", + provides="centroid", + description="mean of all perturbations — shared baseline for the ranking protocols", +) +def src_global_mean(ctx, pert): + """Pseudobulk mean over all perturbations (no target exclusion).""" + return ctx.ds.allpert_mean() + + +@SOURCES.register( + "interpolated", + provides="centroid", + description="interpolated duplicate — DE-weighted blend of the held-out half and " + "the dataset mean (pseudobulk positive control)", +) +def src_interpolated(ctx, pert): + """DE-weighted blend toward the held-out replicate, else the all-perturbed mean. + + Alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward + the held-out replicate where the gene is significant, else toward the all-perturbed mean. + """ + tech = np.asarray(to_dense(ctx.ds.cells(pert, half="second"))).mean(0) + alpha = np.nan_to_num(1.0 - ctx.de(pert, "tech_dup", "control").pvalue_adj, nan=0.0) + return alpha * tech + (1.0 - alpha) * ctx.ds.allpert_mean_except(pert) diff --git a/src/scperteval/types.py b/src/scperteval/types.py new file mode 100644 index 0000000..261a99b --- /dev/null +++ b/src/scperteval/types.py @@ -0,0 +1,241 @@ +"""Core dataclasses shared across the package.""" + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field, replace +from functools import partial + +import numpy as np + + +@dataclass +class RunConfig: + """Resolved options for a single run. + + Attributes + ---------- + dataset : str + Path to the preprocessed ``.h5ad`` file. + protocols : list of str + Names of the resolved (concrete) protocols to run. + de_method : str + DE backend for every DE-dependent unit (default ``"t-test"``). + subsample : int + Number of cells in the all-perturbed reference sample (default 8192). + seed : int + Random seed for subsampling and reproducibility (default 42). + positive : str + Override the positive control source; ``"auto"`` defers to the protocol. + negative : str + Override the negative control source; ``"auto"`` defers to the protocol. + truth : str + Label of the ground-truth condition used for DE (the perturbation key whose + cells serve as the ground-truth target; default ``"gt_half"``). + predictions : str or None + Path to a model-predictions ``.h5ad`` for prediction-scoring mode; + ``None`` selects calibration mode (default). + output : str + Calibrator to apply — ``"drf"`` or ``"bds"`` (default ``"drf"``). + out_dir : str + Directory for output CSV files (default ``"results"``). + workers : int + Number of worker threads; 0 auto-detects (default 0). + min_cells : int + Minimum cells required to evaluate a perturbation (default 30). + perturbation_key : str + ``adata.obs`` column holding perturbation labels (default ``"perturbation"``). + control_label : str + Label in ``perturbation_key`` that identifies control cells (default ``"control"``). + profile : bool + If ``True``, also write a per-protocol wall-clock timing CSV. + """ + + dataset: str + protocols: list[str] + de_method: str = "t-test" + subsample: int = 8192 + seed: int = 42 + positive: str = "auto" + negative: str = "auto" + truth: str = "gt_half" + predictions: str | None = None + output: str = "drf" + out_dir: str = "results" + workers: int = 0 + min_cells: int = 30 + perturbation_key: str = "perturbation" + control_label: str = "control" + profile: bool = False + + +@dataclass(frozen=True) +class DEResult: + """Per-gene differential expression for one target-vs-reference comparison. + + Attributes + ---------- + score : numpy.ndarray + Per-gene test statistic (e.g. t-statistic or Cliff's delta), shape ``(G,)``. + pvalue : numpy.ndarray + Raw per-gene p-values, shape ``(G,)``. + pvalue_adj : numpy.ndarray + Benjamini-Hochberg adjusted p-values, shape ``(G,)``. + extra : dict + Optional method-specific extras (e.g. ``{"u": u_statistic}`` for MWU). + """ + + score: np.ndarray + pvalue: np.ndarray + pvalue_adj: np.ndarray + extra: dict = field(default_factory=dict) + + +@dataclass(frozen=True) +class Param: + """A protocol's tunable knob: how a CLI value is cast, defaulted, and applied. + + ``space`` maps the value (``k=30``, ``padj=0.05``) to a feature-space name; when it is + ``None`` the value is passed straight to the metric as a keyword argument. + + Attributes + ---------- + name : str + Keyword argument name passed to the metric or space factory (e.g. ``"k"``). + cast : Callable + Type cast applied to the CLI string (e.g. ``int``, ``float``). + default : float + Default value used when no value is given on the CLI. + space : Callable or None + Factory mapping the value to a feature-space name (e.g. ``top_space``); + ``None`` means the value is passed directly to the metric as a keyword argument. + """ + + name: str + cast: Callable + default: float + space: Callable | None = None + + +@dataclass(frozen=True) +class Protocol: + """An evaluation protocol: a pure metric plus its data and control wiring. + + ``representation`` and ``scope`` are independent and together decide what the metric + receives: + + - ``representation`` — the shape of one perturbation's datapoint: ``"centroid"`` (a 1-D + pseudobulk vector), ``"population"`` (a cells × genes matrix), or ``"de"`` (a DEResult). + - ``scope`` — ``"perturbation"`` (default): the metric is called once per perturbation, + gets that perturbation's ``(gt, prediction)`` datapoints, and returns a scalar. + ``"dataset"``: the metric is called once, gets the list of *every* perturbation's + ``gt`` and ``prediction`` datapoints, and returns one score per perturbation (e.g. a + cross-perturbation retrieval rank). + + Set ``param`` to make the protocol *tunable* — its feature space (or a metric argument) + is then chosen per run from a CLI value, e.g. ``-p mse_top_k=30``; with no value the + parameter's default is used. Leave ``param`` unset for a fully-specified protocol. + + ``better`` and ``perfect`` describe the metric's score scale and are independent: + + - ``better`` — which direction is an improvement, ``"higher"`` or ``"lower"``. + Correlations and overlaps improve as they go up (``"higher"``); errors and + distances improve as they go down (``"lower"``). This is the metric's *sense*, + and it is not implied by ``perfect`` — e.g. perplexity has ``perfect=1.0`` yet + ``better="lower"``, and a log-likelihood has ``perfect=0.0`` yet ``better="higher"``. + - ``perfect`` — the value a flawless prediction attains (1.0 for a correlation, + 0.0 for an error). It anchors the top of the DRF scale. + + Together they let the calibrator orient the score: DRF measures how far the positive + control moves from the negative-control floor toward ``perfect`` in the ``better`` + direction, and BDS counts the perturbations where the positive control is ``better`` + than the negative. + + Attributes + ---------- + name : str + Protocol identifier used on the CLI (``-p name``) and in output CSVs. + metric : Callable + Pure metric function ``(gt, prediction, ctx) -> float``. + representation : str + Shape of each datapoint: ``"centroid"``, ``"population"``, or ``"de"``. + scope : str + ``"perturbation"`` (default) or ``"dataset"`` — how many perturbations the metric sees at once. + space : str + Feature space applied before scoring (default ``"full"``). + centering : str or None + Baseline subtracted before scoring — ``"ctrl"``, ``"allpert"``, or ``None``. + reference : str + Source used as the reference for the GT DE computation (default ``"all_perturbed"``). + neg_reference : str or None + Reference for the negative-control DE computation; ``None`` uses ``reference``. + better : str + ``"higher"`` or ``"lower"`` — which direction improves the score. + perfect : float + Score a flawless prediction attains (e.g. 1.0 for correlations, 0.0 for errors). + positive : str + Positive control source name (default ``"auto"``, deferring to the protocol). + negative : str + Negative control source name (default ``"auto"``, deferring to the protocol). + group : str + Display group for ``scperteval list protocols`` (e.g. ``"pseudobulk"``). + param : ~scperteval.types.Param or None + If set, makes the protocol tunable from the CLI; ``None`` for fixed protocols. + """ + + name: str + metric: Callable + representation: str + scope: str = "perturbation" + space: str = "full" + centering: str | None = None + reference: str = "all_perturbed" + neg_reference: str | None = None + better: str = "higher" + perfect: float = 1.0 + positive: str = "auto" + negative: str = "auto" + group: str = "" + param: Param | None = None + + @property + def parameterised(self) -> bool: + """Whether this protocol takes a CLI-supplied parameter.""" + return self.param is not None + + def resolve(self, value) -> Protocol: + """Concrete protocol for a tunable one at ``value`` (sets the space or metric arg).""" + assert self.param is not None # resolve() is only called on parameterised protocols + suffix = f"{value:g}" if isinstance(value, float) else str(value) + name = f"{self.name}={suffix}" + if self.param.space is not None: + return replace(self, name=name, space=self.param.space(value), param=None) + metric = partial(self.metric, **{self.param.name: value}) + return replace(self, name=name, metric=metric, param=None) + + +@dataclass(frozen=True) +class Calibrator: + """Turns per-control raw metric values into per-perturbation and aggregate scores. + + Attributes + ---------- + name : str + Registry key and output column name (e.g. ``"drf"``). + requires : tuple of str + Control roles needed — typically ``("positive", "negative")``. + per_pert : Callable + ``(raws: dict, protocol: Protocol) -> float`` — combines raw control values + into one per-perturbation calibrated score. + aggregate : Callable + ``(values: numpy.ndarray) -> dict`` — reduces per-perturbation scores into + summary statistics (e.g. ``{"mean": …, "median": …}``). + description : str + Human-readable description shown by ``scperteval list calibrators``. + """ + + name: str + requires: tuple[str, ...] + per_pert: Callable + aggregate: Callable + description: str = "" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a3d3e06 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +"""Shared fixtures and tiny in-memory dataset builders for the test suite.""" + +from __future__ import annotations + +import anndata as ad +import numpy as np +import pytest + +from scperteval.types import RunConfig + +# Each perturbation gets a strong, *distinct* block of DE genes, so ground-truth DEGs +# form a proper subset (de_auprc/auroc are well-defined) and the perturbation signal is +# unambiguous for the calibration controls. +_DE_GENES = { + "pertA": range(0, 6), + "pertB": range(15, 21), + "pertC": range(30, 36), + "pertD": range(45, 51), +} + + +def make_dataset(seed: int = 0, ng: int = 60, n_ctrl: int = 150, n_pert: int = 120) -> ad.AnnData: + """A tiny log-normalised-looking dataset: control + 4 perturbations with distinct DE blocks.""" + rng = np.random.default_rng(seed) + parts = [rng.poisson(1.0, (n_ctrl, ng)).astype(np.float32)] + labels = ["control"] * n_ctrl + for lab, genes in _DE_GENES.items(): + x = rng.poisson(1.0, (n_pert, ng)).astype(np.float32) + x[:, list(genes)] += 6.0 + parts.append(x) + labels += [lab] * n_pert + adata = ad.AnnData(np.vstack(parts)) + adata.var_names = [f"g{i}" for i in range(ng)] + adata.obs["perturbation"] = labels + return adata + + +def make_predictions( + dataset: ad.AnnData, kind: str = "perfect", shuffle_genes: bool = False, seed: int = 1 +) -> ad.AnnData: + """Build a prediction AnnData from the dataset's perturbed cells. + + ``perfect`` is an exact replica (should score optimally); ``degraded`` shrinks each cell + toward the control mean plus noise (a worse prediction). ``shuffle_genes`` permutes the + gene columns to exercise the name-based alignment. + """ + rng = np.random.default_rng(seed) + pert = np.asarray(dataset.obs["perturbation"]).astype(str) + mask = pert != "control" + sub = dataset[mask].copy() + x = np.asarray(sub.X, dtype=np.float32) + if kind == "degraded": + ctrl_mean = np.asarray(dataset.X[pert == "control"]).mean(0) + x = np.clip(0.4 * x + 0.6 * ctrl_mean + rng.normal(0, 0.2, x.shape), 0, None).astype(np.float32) + elif kind != "perfect": + raise ValueError(f"unknown prediction kind {kind!r}") + pred = ad.AnnData(x, obs=sub.obs.copy()) + pred.var_names = list(dataset.var_names) + if shuffle_genes: + pred = pred[:, rng.permutation(pred.n_vars)].copy() + return pred + + +def make_cfg(**kw) -> RunConfig: + """A RunConfig with small, fast, deterministic defaults for tests.""" + base = dict(dataset="-", protocols=[], de_method="t-test", subsample=400, seed=0, min_cells=10, workers=1) + base.update(kw) + return RunConfig(**base) + + +@pytest.fixture +def dataset_adata() -> ad.AnnData: + return make_dataset() + + +@pytest.fixture +def dataset_path(tmp_path, dataset_adata) -> str: + path = tmp_path / "dataset.h5ad" + dataset_adata.write_h5ad(path) + return str(path) + + +@pytest.fixture +def cfg_factory(): + return make_cfg + + +@pytest.fixture +def predictions_factory(): + return make_predictions diff --git a/tests/test_calibrate.py b/tests/test_calibrate.py new file mode 100644 index 0000000..c81111a --- /dev/null +++ b/tests/test_calibrate.py @@ -0,0 +1,53 @@ +"""Calibration mode: DRF/BDS over built-in positive/negative controls.""" + +from __future__ import annotations + +import numpy as np + +from scperteval.calibrators import CALIBRATORS +from scperteval.cli import _concrete +from scperteval.context import Context +from scperteval.dataset import Dataset +from scperteval.protocols.table import PROTOCOLS +from scperteval.runner import run_protocol + + +def _run(name, calibrator, dataset_adata, cfg): + ctx = Context(Dataset(dataset_adata, cfg), cfg) + return run_protocol(_concrete(PROTOCOLS[name]), ctx, CALIBRATORS[calibrator]) + + +def test_calibrators_registered(): + assert {"drf", "bds", "score"} <= set(CALIBRATORS) + # drf/bds need both controls; score needs only the prediction + assert CALIBRATORS["drf"].requires == ("positive", "negative") + assert CALIBRATORS["score"].requires == ("prediction",) + + +def test_drf_rows_have_control_columns(dataset_adata, cfg_factory): + agg, rows, seconds = _run("pearson_ctrl", "drf", dataset_adata, cfg_factory()) + assert seconds >= 0.0 + assert len(rows) == 4 # one row per perturbation + cols = set(rows[0]) + assert {"protocol", "perturbation", "raw_positive", "raw_negative", "drf"} <= cols + assert {"mean", "median"} <= set(agg) + + +def test_drf_positive_for_real_signal(dataset_adata, cfg_factory): + # the positive control (held-out replicate) should beat the uninformative baseline, + # so mean DRF is clearly positive on a dataset with strong perturbation signal. + for name in ("pearson_ctrl", "mse"): + agg, _, _ = _run(name, "drf", dataset_adata, cfg_factory()) + assert agg["mean"] > 0.0, name + + +def test_bds_is_a_fraction(dataset_adata, cfg_factory): + agg, rows, _ = _run("pearson_ctrl", "bds", dataset_adata, cfg_factory()) + assert 0.0 <= agg["bds"] <= 1.0 + assert all(r["bds"] in (0.0, 1.0) for r in rows) # per-perturbation BDS is binary + + +def test_de_protocol_calibrates(dataset_adata, cfg_factory): + # the de representation should produce a finite, well-defined auprc (distinct DE blocks) + agg, _, _ = _run("de_auprc", "drf", dataset_adata, cfg_factory()) + assert np.isfinite(agg["mean"]) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..f0382f6 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,35 @@ +"""End-to-end CLI dispatch for the calibrate / score / de subcommands.""" + +from __future__ import annotations + +import pytest + +from scperteval.cli import main + + +def test_calibrate_writes_drf_csv(dataset_path, tmp_path): + main(["calibrate", dataset_path, "-p", "pearson_ctrl,mse", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__drf.csv"))) == 1 + + +def test_calibrate_bds_output(dataset_path, tmp_path): + main(["calibrate", dataset_path, "-p", "mse", "--output", "bds", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__bds.csv"))) == 1 + + +def test_score_writes_score_csv(dataset_path, dataset_adata, predictions_factory, tmp_path): + pred_path = tmp_path / "pred.h5ad" + predictions_factory(dataset_adata, kind="degraded").write_h5ad(pred_path) + main(["score", dataset_path, str(pred_path), "-p", "pearson,mse,de_auprc", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__score.csv"))) == 1 + + +def test_de_writes_h5(dataset_path, tmp_path): + main(["de", dataset_path, "--methods", "t-test", "--out-dir", str(tmp_path)]) + assert len(list(tmp_path.glob("*__de.h5"))) == 1 + + +def test_calibrate_rejects_score_output(dataset_path, tmp_path): + # `score` is a scoring-mode calibrator, not selectable from `calibrate --output` + with pytest.raises(SystemExit): + main(["calibrate", dataset_path, "-p", "mse", "--output", "score", "--out-dir", str(tmp_path)]) diff --git a/tests/test_de.py b/tests/test_de.py index 82bf471..1c1bc56 100644 --- a/tests/test_de.py +++ b/tests/test_de.py @@ -1,5 +1,6 @@ """Tests for the differential-expression backends, focused on the scanpy ``t-test_overestim_var`` variant added as a selectable DE method.""" + from __future__ import annotations import anndata as ad @@ -18,8 +19,7 @@ def _scanpy_overestim(Xt, Xr): adata.var_names = names adata.obs["g"] = ["target"] * Xt.shape[0] + ["reference"] * Xr.shape[0] adata.obs["g"] = adata.obs["g"].astype("category") - sc.tl.rank_genes_groups(adata, "g", groups=["target"], reference="reference", - method="t-test_overestim_var") + sc.tl.rank_genes_groups(adata, "g", groups=["target"], reference="reference", method="t-test_overestim_var") res = adata.uns["rank_genes_groups"] order = np.array([int(n) for n in res["names"]["target"]]) scores = np.empty(ng) @@ -32,8 +32,8 @@ def _scanpy_overestim(Xt, Xr): def test_overestim_var_matches_scanpy(): """Our backend reproduces scanpy's t-test_overestim_var statistic and p-values.""" rng = np.random.default_rng(0) - Xt = rng.poisson(1.0, (40, 60)).astype(np.float64) # small target group - Xr = rng.poisson(1.3, (90, 60)).astype(np.float64) # larger reference + Xt = rng.poisson(1.0, (40, 60)).astype(np.float64) # small target group + Xr = rng.poisson(1.3, (90, 60)).astype(np.float64) # larger reference de = de_ttest_overestim(Xt, Xr) sc_scores, sc_pvals = _scanpy_overestim(Xt, Xr) assert np.allclose(de.score, sc_scores, atol=1e-5, rtol=1e-4) @@ -76,8 +76,9 @@ def test_overestim_var_runs_through_export_path(): adata = ad.AnnData(np.vstack(parts).astype(np.float64)) adata.var_names = [f"g{i}" for i in range(ng)] adata.obs["perturbation"] = labels - cfg = RunConfig(dataset="-", protocols=[], de_method="t-test_overestim_var", - subsample=200, seed=0, min_cells=10, workers=1) + cfg = RunConfig( + dataset="-", protocols=[], de_method="t-test_overestim_var", subsample=200, seed=0, min_cells=10, workers=1 + ) ctx = Context(Dataset(adata, cfg), cfg) out = compute_de_export(ctx, ["t-test_overestim_var"]) stat, padj = out["t-test_overestim_var"] diff --git a/tests/test_score.py b/tests/test_score.py new file mode 100644 index 0000000..28f8ba9 --- /dev/null +++ b/tests/test_score.py @@ -0,0 +1,82 @@ +"""Prediction-scoring mode: score predictions against ground truth, and PredictionSet.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from scperteval.calibrators import CALIBRATORS +from scperteval.cli import _concrete +from scperteval.context import Context +from scperteval.dataset import Dataset +from scperteval.predictions import PredictionSet +from scperteval.protocols.table import PROTOCOLS +from scperteval.runner import run_protocol + + +def _score(name, dataset_adata, pred_adata, cfg): + ds = Dataset(dataset_adata, cfg) + ctx = Context(ds, cfg) + ctx.predictions = PredictionSet(pred_adata, ds, cfg) + return run_protocol(_concrete(PROTOCOLS[name]), ctx, CALIBRATORS["score"]) + + +def _score_cfg(cfg_factory): + return cfg_factory(truth="gt_all_cells", output="score") + + +def test_score_rows_have_prediction_column(dataset_adata, predictions_factory, cfg_factory): + pred = predictions_factory(dataset_adata, kind="perfect") + agg, rows, _ = _score("mse", dataset_adata, pred, _score_cfg(cfg_factory)) + assert {"protocol", "perturbation", "raw_prediction", "score"} <= set(rows[0]) + assert {"mean", "median"} <= set(agg) + + +def test_perfect_prediction_is_optimal(dataset_adata, predictions_factory, cfg_factory): + # an exact replica of the real cells must score optimally on every representation, + # even with the prediction's gene columns shuffled (name-based alignment). + pred = predictions_factory(dataset_adata, kind="perfect", shuffle_genes=True) + cfg = _score_cfg(cfg_factory) + assert _score("pearson", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(1.0, abs=1e-6) + assert _score("mse", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(0.0, abs=1e-6) + assert _score("de_auprc", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(1.0, abs=1e-6) + + +def test_degraded_prediction_scores_worse(dataset_adata, predictions_factory, cfg_factory): + cfg = _score_cfg(cfg_factory) + perfect = predictions_factory(dataset_adata, kind="perfect") + degraded = predictions_factory(dataset_adata, kind="degraded") + + def mean(name, pred): + return _score(name, dataset_adata, pred, cfg)[0]["mean"] + + assert mean("mse", degraded) > mean("mse", perfect) # error up + assert mean("de_auprc", degraded) < mean("de_auprc", perfect) # auprc down + + +def test_gene_set_mismatch_raises(dataset_adata, predictions_factory, cfg_factory): + ds = Dataset(dataset_adata, cfg_factory()) + pred_missing = predictions_factory(dataset_adata, kind="perfect")[:, :-1].copy() + with pytest.raises(ValueError, match="gene mismatch"): + PredictionSet(pred_missing, ds, cfg_factory()) + + +def test_missing_perturbation_raises(dataset_adata, predictions_factory, cfg_factory): + ds = Dataset(dataset_adata, cfg_factory()) + pred = predictions_factory(dataset_adata, kind="perfect") + only_a = pred[np.asarray(pred.obs["perturbation"]) == "pertA"].copy() + ps = PredictionSet(only_a, ds, cfg_factory()) + with pytest.raises(ValueError, match="no cells for perturbation"): + ps.cells("pertB") + + +def test_gene_alignment_reorders_by_name(dataset_adata, predictions_factory, cfg_factory): + # a shuffled-gene prediction is reordered to the dataset's gene order + ds = Dataset(dataset_adata, cfg_factory()) + pred = predictions_factory(dataset_adata, kind="perfect", shuffle_genes=True) + ps = PredictionSet(pred, ds, cfg_factory()) + cells = np.asarray(ps.cells("pertA")) + assert cells.shape[1] == len(ds.var_names) + # pertA's DE block (genes 0-5) should be the high-expression columns after realignment + col_means = cells.mean(0) + assert col_means[list(range(0, 6))].min() > col_means[10:].max()