From 54a4d9808233814d81f2356aa1c0be61eaedac73 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 15 Jun 2026 09:45:56 +0100 Subject: [PATCH 01/15] fix: illico requires python>=3.11 --- .gitignore | 5 +++++ pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2fd5a41..1fcaa4a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,11 @@ dist/ .venv/ venv/ +# Environment +uv.lock +.envrc +requirements-local.txt + # Run outputs results/ diff --git a/pyproject.toml b/pyproject.toml index 2d0924d..25bb2c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "scperteval" version = "0.1.0" description = "Evaluation Protocols for Perturbation Studies: per-metric DRF/BDS calibration on a single preprocessed dataset." -requires-python = ">=3.10" +requires-python = ">=3.11" dependencies = [ "anndata", "scanpy", From 4d8923d2b02e66d82a93823732056ef19b884f05 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 15 Jun 2026 09:52:26 +0100 Subject: [PATCH 02/15] feat: add linting: ruff, mypy, pyright --- pyproject.toml | 56 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 25bb2c0..2cc6b94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,16 @@ dependencies = [ "h5py", ] +[dependency-groups] +lint = [ + "ruff", + "mypy", + "pyright", +] +dev = [ + { include-group = "lint" }, +] + [project.scripts] scperteval = "scperteval.cli:main" @@ -28,3 +38,49 @@ build-backend = "hatchling.build" [tool.hatch.build.targets.wheel] packages = ["scperteval"] + +[tool.ruff] +line-length = 120 +src = ["scperteval"] +extend-include = ["*.ipynb"] +extend-exclude = ["deprecated"] +target-version = "py311" +format.docstring-code-format = true +lint.select = [ + "B", # flake8-bugbear — common bug patterns (mutable defaults, unused loop vars…) + "BLE", # flake8-blind-except — catch bare `except Exception` + "C4", # flake8-comprehensions — idiomatic list/dict/set comprehensions + "D", # pydocstyle — numpy docstring convention + "E", # pycodestyle errors (E722 bare except, E731 lambda assign, E741 ambiguous names…) + "F", # Pyflakes (unused imports F401, undefined names, unused variables…) + "I", # isort — consistent import ordering + "PTH", # flake8-use-pathlib — replace os.path.* and open() with pathlib equivalents + "RUF", # Ruff-specific — list unpacking, quadratic sum, unused unpacked vars… + "SIM", # flake8-simplify — simplify if/else, dict.keys() iteration… + "UP", # pyupgrade — modern Python syntax (UP006 built-in generics, UP007 x | None…) + "W", # pycodestyle warnings (W291 trailing whitespace, W605 invalid escape sequence…) +] +lint.ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic methods + "D107", # Missing docstring in __init__ + "D203", # No blank line before class docstring (incompatible with D211) + "D213", # Multi-line summary second line (incompatible with D212) + "D400", # First line should end with period (breaks single-line docstrings) + "D401", # First line in imperative mood + "E501", # line too long — formatter enforces line-length = 120, remaining cases are unfixable strings/comments + "RUF001", # ambiguous unicode in strings — intentional (×, –, ➕ used in UI labels and math) + "RUF002", # ambiguous unicode in docstrings — same rationale + "RUF003", # ambiguous unicode in comments — same rationale +] +lint.per-file-ignores."*/__init__.py" = ["F401"] +lint.per-file-ignores."docs/*" = ["I"] +lint.per-file-ignores."tests/*" = ["D"] +lint.pydocstyle.convention = "numpy" + +[tool.mypy] +ignore_missing_imports = true +explicit_package_bases = true +check_untyped_defs = true +warn_unreachable = true \ No newline at end of file From 00195250720ed9b53cd3bd6b526875f1a0944cb1 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 15 Jun 2026 16:52:09 +0100 Subject: [PATCH 03/15] docs: add sphinx setup inspired from scverse --- .gitignore | 4 + .readthedocs.yaml | 16 +++ CHANGELOG.md | 5 + docs/_static/.gitkeep | 0 docs/_static/css/custom.css | 4 + docs/_templates/.gitkeep | 0 docs/_templates/autosummary/class.rst | 59 ++++++++++ docs/api.md | 113 +++++++++++++++++++ docs/changelog.md | 3 + docs/conf.py | 121 ++++++++++++++++++++ docs/contributing.md | 56 ++++++++++ docs/extensions/typed_returns.py | 31 ++++++ docs/index.md | 13 +++ docs/notebooks/.gitkeep | 0 docs/references.bib | 10 ++ docs/references.md | 5 + pyproject.toml | 155 ++++++++++++++++---------- 17 files changed, 535 insertions(+), 60 deletions(-) create mode 100644 .readthedocs.yaml create mode 100644 CHANGELOG.md create mode 100644 docs/_static/.gitkeep create mode 100644 docs/_static/css/custom.css create mode 100644 docs/_templates/.gitkeep create mode 100644 docs/_templates/autosummary/class.rst create mode 100644 docs/api.md create mode 100644 docs/changelog.md create mode 100644 docs/conf.py create mode 100644 docs/contributing.md create mode 100644 docs/extensions/typed_returns.py create mode 100644 docs/index.md create mode 100644 docs/notebooks/.gitkeep create mode 100644 docs/references.bib create mode 100644 docs/references.md diff --git a/.gitignore b/.gitignore index 1fcaa4a..b84bb62 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,7 @@ results/ .idea/ .vscode/ .DS_Store + +# docs +docs/generated/ +docs/_build/ diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..fe7f5fe --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,16 @@ +# https://docs.readthedocs.io/en/stable/config-file/v2.html +version: 2 +build: + os: ubuntu-24.04 + tools: + python: "3.12" + jobs: + create_environment: + - asdf plugin add uv + - asdf install uv latest + - asdf global uv latest + build: + html: + - uv sync --group docs + - uv run sphinx-build -M html docs docs/_build -W + - mv docs/_build $READTHEDOCS_OUTPUT diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..e2e763c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Changelog + +## 0.1.0 (unreleased) + +Initial implementation of scPertEval. diff --git a/docs/_static/.gitkeep b/docs/_static/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/_static/css/custom.css b/docs/_static/css/custom.css new file mode 100644 index 0000000..b8c8d47 --- /dev/null +++ b/docs/_static/css/custom.css @@ -0,0 +1,4 @@ +/* Reduce the font size in data frames - See https://github.com/scverse/cookiecutter-scverse/issues/193 */ +div.cell_output table.dataframe { + font-size: 0.8em; +} diff --git a/docs/_templates/.gitkeep b/docs/_templates/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 0000000..67da490 --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,59 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + +{% block attributes %} +{% if attributes %} +Attributes table +~~~~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in attributes %} + ~{{ name }}.{{ item }} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block methods %} +{% if methods %} +Methods table +~~~~~~~~~~~~~ + +.. autosummary:: +{% for item in methods %} + {%- if item != '__init__' %} + ~{{ name }}.{{ item }} + {%- endif -%} +{%- endfor %} +{% endif %} +{% endblock %} + +{% block attributes_documentation %} +{% if attributes %} +Attributes +~~~~~~~~~~ + +{% for item in attributes %} + +.. autoattribute:: {{ [objname, item] | join(".") }} +{%- endfor %} + +{% endif %} +{% endblock %} + +{% block methods_documentation %} +{% if methods %} +Methods +~~~~~~~ + +{% for item in methods %} +{%- if item != '__init__' %} + +.. automethod:: {{ [objname, item] | join(".") }} +{%- endif -%} +{%- endfor %} + +{% endif %} +{% endblock %} diff --git a/docs/api.md b/docs/api.md new file mode 100644 index 0000000..71467b5 --- /dev/null +++ b/docs/api.md @@ -0,0 +1,113 @@ +# API + +## Core types + +```{eval-rst} +.. module:: scperteval.types +.. currentmodule:: scperteval.types + +.. autosummary:: + :toctree: generated + + RunConfig + Protocol + Calibrator +``` + +## Runner + +```{eval-rst} +.. module:: scperteval.runner +.. currentmodule:: scperteval.runner + +.. autosummary:: + :toctree: generated + + run_protocol +``` + +## Protocols + +```{eval-rst} +.. module:: scperteval.protocols +.. currentmodule:: scperteval.protocols + +.. autosummary:: + :toctree: generated + + PROTOCOLS + GROUPS + TABLE +``` + +### Metrics + +```{eval-rst} +.. module:: scperteval.protocols.metrics +.. currentmodule:: scperteval.protocols.metrics + +.. autosummary:: + :toctree: generated +``` + +## Calibrators + +```{eval-rst} +.. module:: scperteval.calibrators +.. currentmodule:: scperteval.calibrators + +.. autosummary:: + :toctree: generated +``` + +## Building blocks + +### Differential expression + +```{eval-rst} +.. module:: scperteval.blocks.de +.. currentmodule:: scperteval.blocks.de + +.. autosummary:: + :toctree: generated +``` + +### Feature spaces + +```{eval-rst} +.. module:: scperteval.blocks.spaces +.. currentmodule:: scperteval.blocks.spaces + +.. autosummary:: + :toctree: generated +``` + +## Registry + +```{eval-rst} +.. module:: scperteval.registry +.. currentmodule:: scperteval.registry + +.. autosummary:: + :toctree: generated + + Registry +``` + +## Dataset & I/O + +```{eval-rst} +.. module:: scperteval.dataset +.. currentmodule:: scperteval.dataset + +.. autosummary:: + :toctree: generated +``` + +```{eval-rst} +.. module:: scperteval.io +.. currentmodule:: scperteval.io + +.. autosummary:: + :toctree: generated +``` diff --git a/docs/changelog.md b/docs/changelog.md new file mode 100644 index 0000000..d9e79ba --- /dev/null +++ b/docs/changelog.md @@ -0,0 +1,3 @@ +```{include} ../CHANGELOG.md + +``` diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..cdb841c --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,121 @@ +# Configuration file for the Sphinx documentation builder. +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +import shutil +import sys +from datetime import datetime +from importlib.metadata import metadata +from pathlib import Path + +from sphinxcontrib import katex + +HERE = Path(__file__).parent +sys.path.insert(0, str(HERE / "extensions")) + +# -- Project information ----------------------------------------------------- + +info = metadata("scperteval") +project = info["Name"] +author = info.get("Author") or "scPertEval authors" +copyright = f"{datetime.now():%Y}, {author}" +version = info["Version"] +_project_urls = info.get_all("Project-URL") or [] +urls = dict(pu.split(", ", 1) for pu in _project_urls) +repository_url = urls.get("Source", "https://github.com/Virtual-Cell-Research-Community/scPertEval") + +release = info["Version"] + +bibtex_bibfiles = ["references.bib"] +templates_path = ["_templates"] +nitpicky = True +needs_sphinx = "4.0" + +html_context = { + "display_github": True, + "github_user": "Virtual-Cell-Research-Community", + "github_repo": "scPertEval", + "github_version": "main", + "conf_py_path": "/docs/", +} + +# -- General configuration --------------------------------------------------- + +extensions = [ + "myst_nb", + "sphinx_copybutton", + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinxcontrib.bibtex", + "sphinxcontrib.katex", + "sphinx_autodoc_typehints", + "sphinx_design", + "IPython.sphinxext.ipython_console_highlighting", + "sphinxext.opengraph", + *[p.stem for p in (HERE / "extensions").glob("*.py")], +] + +autosummary_generate = True +autodoc_member_order = "groupwise" +default_role = "literal" +napoleon_google_docstring = False +napoleon_numpy_docstring = True +napoleon_include_init_with_doc = False +napoleon_use_rtype = True +napoleon_use_param = True +myst_heading_anchors = 6 +myst_enable_extensions = [ + "amsmath", + "colon_fence", + "deflist", + "dollarmath", + "html_image", + "html_admonition", +] +myst_url_schemes = ("http", "https", "mailto") +nb_output_stderr = "remove" +nb_execution_mode = "off" +nb_merge_streams = True +typehints_defaults = "braces" +always_use_bars_union = True + +source_suffix = { + ".rst": "restructuredtext", + ".ipynb": "myst-nb", + ".myst": "myst-nb", +} + +intersphinx_mapping = { + "python": ("https://docs.python.org/3", None), + "anndata": ("https://anndata.readthedocs.io/en/stable/", None), + "scanpy": ("https://scanpy.readthedocs.io/en/stable/", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/", None), + "sklearn": ("https://scikit-learn.org/stable/", None), +} + +exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "**.ipynb_checkpoints"] + +# -- Options for HTML output ------------------------------------------------- + +html_theme = "sphinx_book_theme" +html_static_path = ["_static"] +html_css_files = ["css/custom.css"] + +html_title = project + +html_theme_options = { + "repository_url": repository_url, + "use_repository_button": True, + "path_to_docs": "docs/", + "navigation_with_keys": False, +} + +pygments_style = "default" +katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None + +nitpick_ignore = [ # type: ignore + # Add exceptions here for links outside your control that fail to resolve + # ("py:class", "igraph.Graph"), +] diff --git a/docs/contributing.md b/docs/contributing.md new file mode 100644 index 0000000..3771dea --- /dev/null +++ b/docs/contributing.md @@ -0,0 +1,56 @@ +# Contributing guide + +We welcome contributions! Please open an issue or pull request on [GitHub](https://github.com/Virtual-Cell-Research-Community/scPertEval). + +## Installing dev dependencies + +:::::{tab-set} +::::{tab-item} uv + +```bash +uv sync --group dev +``` + +:::: + +::::{tab-item} pip + +```bash +pip install -e ".[dev]" +``` + +:::: +::::: + +## Code style + +This project uses [ruff][] for formatting and linting, and [mypy][]/[pyright][] for type checking. + +```bash +uv run ruff format . +uv run ruff check . +uv run mypy scperteval +``` + +[ruff]: https://docs.astral.sh/ruff/ + +## Building the docs locally + +```bash +uv sync --group doc +uv run sphinx-build -M html docs docs/_build -W +``` + +Then open `docs/_build/html/index.html`. + +## Publishing a release + +Update the version in `pyproject.toml`, commit, push, and create a GitHub release tagged `vX.Y.Z`. + +## Writing documentation + +- Use [numpy-style docstrings][numpydoc]. +- Add tutorials as Jupyter notebooks in `docs/notebooks/`. +- Add intersphinx entries to `docs/conf.py` for cross-references to external packages. + +[numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py new file mode 100644 index 0000000..1c2e393 --- /dev/null +++ b/docs/extensions/typed_returns.py @@ -0,0 +1,31 @@ +# code from https://github.com/theislab/scanpy/blob/master/docs/extensions/typed_returns.py +from __future__ import annotations + +import re +from collections.abc import Generator, Iterable + +from sphinx.application import Sphinx +from sphinx.ext.napoleon import NumpyDocstring # type: ignore + + +def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: + for line in lines: + if m := re.fullmatch(r"(?P\w+)\s+:\s+(?P[\w.]+)", line): + yield f"-{m['param']} (:class:`~{m['type']}`)" + else: + yield line + + +def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: + lines_raw = self._dedent(self._consume_to_next_section()) + if lines_raw[0] == ":": + del lines_raw[0] + lines = self._format_block(":returns: ", list(_process_return(lines_raw))) + if lines and lines[-1]: + lines.append("") + return lines + + +def setup(app: Sphinx): + """Set app.""" + NumpyDocstring._parse_returns_section = _parse_returns_section # type: ignore diff --git a/docs/index.md b/docs/index.md new file mode 100644 index 0000000..53ff8f2 --- /dev/null +++ b/docs/index.md @@ -0,0 +1,13 @@ +```{include} ../README.md + +``` + +```{toctree} +:hidden: true +:maxdepth: 1 + +api.md +changelog.md +contributing.md +references.md +``` diff --git a/docs/notebooks/.gitkeep b/docs/notebooks/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docs/references.bib b/docs/references.bib new file mode 100644 index 0000000..0ce3036 --- /dev/null +++ b/docs/references.bib @@ -0,0 +1,10 @@ +@article{Virshup_2023, + doi = {10.1038/s41587-023-01733-8}, + url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, + year = 2023, + month = {apr}, + publisher = {Springer Science and Business Media {LLC}}, + author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis}, + title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, + journal = {Nature Biotechnology} +} diff --git a/docs/references.md b/docs/references.md new file mode 100644 index 0000000..00ad6a6 --- /dev/null +++ b/docs/references.md @@ -0,0 +1,5 @@ +# References + +```{bibliography} +:cited: +``` diff --git a/pyproject.toml b/pyproject.toml index 2cc6b94..a2bc892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,86 +1,121 @@ +[build-system] +build-backend = "hatchling.build" +requires = [ "hatchling" ] + [project] name = "scperteval" version = "0.1.0" description = "Evaluation Protocols for Perturbation Studies: per-metric DRF/BDS calibration on a single preprocessed dataset." +readme = "README.md" +license = { file = "LICENSE" } +authors = [ + { name = "Zach Boldyga" }, +] requires-python = ">=3.11" +classifiers = [ + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", +] dependencies = [ - "anndata", - "scanpy", - "numpy", - "scipy", - "scikit-learn", - "pandas", - "geomloss", - "torch", - "illico", - "h5py", + "anndata", + "geomloss", + "h5py", + "illico", + "numpy", + "pandas", + "scanpy", + "scikit-learn", + "scipy", + "torch", ] +urls.Documentation = "https://scperteval.readthedocs.io/" +urls.Homepage = "https://github.com/Virtual-Cell-Research-Community/scPertEval" +urls.Source = "https://github.com/Virtual-Cell-Research-Community/scPertEval" +scripts.scperteval = "scperteval.cli:main" [dependency-groups] +dev = [ + { include-group = "docs" }, + { include-group = "lint" }, + { include-group = "test" }, +] +docs = [ + "ipykernel", + "ipython", + "myst-nb>=1.1", + "pandas", + "sphinx>=8.1", + "sphinx-autobuild>=2025.8.25", + "sphinx-autodoc-typehints", + "sphinx-book-theme>=1", + "sphinx-copybutton", + "sphinx-design", + "sphinxcontrib-bibtex>=1", + "sphinxcontrib-katex", + "sphinxext-opengraph", +] lint = [ - "ruff", - "mypy", - "pyright", + "mypy", + "pyright", + "ruff", ] -dev = [ - { include-group = "lint" }, +test = [ + "pytest", ] -[project.scripts] -scperteval = "scperteval.cli:main" - -[project.optional-dependencies] -test = ["pytest"] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.hatch.build.targets.wheel] -packages = ["scperteval"] +[tool.hatch] +build.targets.wheel.packages = [ "src/scperteval" ] [tool.ruff] -line-length = 120 -src = ["scperteval"] -extend-include = ["*.ipynb"] -extend-exclude = ["deprecated"] target-version = "py311" +line-length = 120 +src = [ "src" ] +extend-include = [ "*.ipynb" ] format.docstring-code-format = true lint.select = [ - "B", # flake8-bugbear — common bug patterns (mutable defaults, unused loop vars…) - "BLE", # flake8-blind-except — catch bare `except Exception` - "C4", # flake8-comprehensions — idiomatic list/dict/set comprehensions - "D", # pydocstyle — numpy docstring convention - "E", # pycodestyle errors (E722 bare except, E731 lambda assign, E741 ambiguous names…) - "F", # Pyflakes (unused imports F401, undefined names, unused variables…) - "I", # isort — consistent import ordering - "PTH", # flake8-use-pathlib — replace os.path.* and open() with pathlib equivalents - "RUF", # Ruff-specific — list unpacking, quadratic sum, unused unpacked vars… - "SIM", # flake8-simplify — simplify if/else, dict.keys() iteration… - "UP", # pyupgrade — modern Python syntax (UP006 built-in generics, UP007 x | None…) - "W", # pycodestyle warnings (W291 trailing whitespace, W605 invalid escape sequence…) + "B", # flake8-bugbear — common bug patterns (mutable defaults, unused loop vars…) + "BLE", # flake8-blind-except — catch bare `except Exception` + "C4", # flake8-comprehensions — idiomatic list/dict/set comprehensions + "D", # pydocstyle — numpy docstring convention + "E", # pycodestyle errors (E722 bare except, E731 lambda assign, E741 ambiguous names…) + "F", # Pyflakes (unused imports F401, undefined names, unused variables…) + "I", # isort — consistent import ordering + "PTH", # flake8-use-pathlib — replace os.path.* and open() with pathlib equivalents + "RUF", # Ruff-specific — list unpacking, quadratic sum, unused unpacked vars… + "SIM", # flake8-simplify — simplify if/else, dict.keys() iteration… + "UP", # pyupgrade — modern Python syntax (UP006 built-in generics, UP007 x | None…) + "W", # pycodestyle warnings (W291 trailing whitespace, W605 invalid escape sequence…) ] lint.ignore = [ - "D100", # Missing docstring in public module - "D104", # Missing docstring in public package - "D105", # Missing docstring in magic methods - "D107", # Missing docstring in __init__ - "D203", # No blank line before class docstring (incompatible with D211) - "D213", # Multi-line summary second line (incompatible with D212) - "D400", # First line should end with period (breaks single-line docstrings) - "D401", # First line in imperative mood - "E501", # line too long — formatter enforces line-length = 120, remaining cases are unfixable strings/comments - "RUF001", # ambiguous unicode in strings — intentional (×, –, ➕ used in UI labels and math) - "RUF002", # ambiguous unicode in docstrings — same rationale - "RUF003", # ambiguous unicode in comments — same rationale + "B905", # `zip()` without an explicit `strict=` parameter + "C408", # Unnecessary dict(), list() or tuple() calls that can be rewritten as empty literals. + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D105", # Missing docstring in magic methods + "D107", # Missing docstring in __init__ + "D203", # No blank line before class docstring (incompatible with D211) + "D213", # Multi-line summary second line (incompatible with D212) + "D400", # First line should end with period (breaks single-line docstrings) + "D401", # First line in imperative mood + "E501", # line too long — formatter enforces line-length = 120, remaining cases are unfixable strings/comments + "RUF001", # ambiguous unicode in strings — intentional (×, –, ➕ used in UI labels and math) + "RUF002", # ambiguous unicode in docstrings — same rationale + "RUF003", # ambiguous unicode in comments — same rationale ] -lint.per-file-ignores."*/__init__.py" = ["F401"] -lint.per-file-ignores."docs/*" = ["I"] -lint.per-file-ignores."tests/*" = ["D"] +lint.per-file-ignores."*/__init__.py" = [ "F401" ] +lint.per-file-ignores."docs/*" = [ "I" ] +lint.per-file-ignores."tests/*" = [ "D" ] lint.pydocstyle.convention = "numpy" [tool.mypy] +mypy_path = [ "src" ] +exclude = [ "^docs/" ] ignore_missing_imports = true -explicit_package_bases = true check_untyped_defs = true -warn_unreachable = true \ No newline at end of file +warn_unreachable = true + +[tool.pyright] +exclude = [ "**/.*", "**/__pycache__", "**/node_modules", ".venv", "docs/**" ] From 66b5ce7dcb0a75d069f6a548fc3fb2db7d3d8770 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Fri, 26 Jun 2026 16:43:29 +0100 Subject: [PATCH 04/15] docs: simplify README and restructure sphinx docs - README trimmed to install + quick start + sample datasets pointer; full content moved to sphinx docs - user-guide/ folder with usage, scoring, protocols, building-blocks - installation.md with dev setup (moved from contributing) - tutorials.md placeholder - index.md: citation block, grid cards, toctree reorganization - references.bib: add Miller 2025 and Vollenweider 2026 entries - conf.py: show_navbar_depth for stable sidebar --- CONTRIBUTORS.md | 6 +- README.md | 495 +---------------------------- docs/api.md | 14 +- docs/conf.py | 1 + docs/contributing.md | 56 +--- docs/index.md | 82 ++++- docs/installation.md | 44 +++ docs/references.bib | 34 +- docs/tutorials.md | 7 + docs/user-guide/building-blocks.md | 68 ++++ docs/user-guide/index.md | 10 + docs/user-guide/protocols.md | 221 +++++++++++++ docs/user-guide/scoring.md | 36 +++ docs/user-guide/usage.md | 96 ++++++ 14 files changed, 609 insertions(+), 561 deletions(-) create mode 100644 docs/installation.md create mode 100644 docs/tutorials.md create mode 100644 docs/user-guide/building-blocks.md create mode 100644 docs/user-guide/index.md create mode 100644 docs/user-guide/protocols.md create mode 100644 docs/user-guide/scoring.md create mode 100644 docs/user-guide/usage.md diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 32fbed5..37ab212 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -7,9 +7,9 @@ welcome. There are two paths, depending on what you're changing. If you're adding a protocol (a new metric, or a new combination of an existing metric with a space / centering / controls), **open a PR directly.** This is the common case and the -whole point of the project. See [Create a protocol](README.md#create-a-protocol) for the -two-step pattern (a pure function in `scperteval/protocols/algorithms.py` plus a row in -`scperteval/protocols/table.py`). Adding a new building block (feature space, DE method, control +whole point of the project. See [Create a protocol](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/docs/protocols.md#create-a-protocol) for the +two-step pattern (a pure function in `src/scperteval/protocols/metrics.py` plus a row in +`src/scperteval/protocols/table.py`). Adding a new building block (feature space, DE method, control source, calibrator) the same way is also welcome as a PR. Please include: diff --git a/README.md b/README.md index 625359f..0bd9523 100644 --- a/README.md +++ b/README.md @@ -1,504 +1,35 @@ # scPertEval — Evaluation Protocols for Perturbation Sequencing scPertEval is a command-line tool for **experimenting with and sharing reference implementations of -evaluation protocols** in single-cell perturbation studies. +evaluation protocols** in single-cell perturbation studies. It calibrates each protocol +against empirical positive and negative controls per perturbation, outputting the +**Dynamic Range Fraction (DRF)** and the **Bound Discrimination Score (BDS)**. -Evaluating predictions across a dataset's -perturbations reduces to a single question: how different is one group of cells from another? To answer this, an **evaluation protocol** is defined: a specific formulation of a metric, along with some representation of the perturbation data fed to the metric. However, there are a multitude of possibilities -- many already reflected in the literature -- and it can be challenging to compare and contrast protocols across the field and ultimately choose the right approach for a given dataset and problem space. +Our accompanying publication: TODO_LINK_HERE -scPertEval renders each protocol as a short, readable building block to run, read, reuse, and contribute back -- a place for -collaboration and alignment in the field. - -The same catalog of protocols backs three commands, each a different use case: - -- **`score`** — score a model's predictions against ground truth. Each protocol's metric is - applied to your **predicted** cells vs the **real** cells, one score per perturbation — the - conventional "how good is my prediction" evaluation (see - [Scoring predictions](#scoring-predictions-against-ground-truth)). -- **`calibrate`** — calibrate a protocol against empirical positive/negative controls built from - the dataset itself, reporting the **Dynamic Range Fraction (DRF)** and the **Bound Discrimination - Score (BDS)** — quantifying how well the protocol separates real perturbation signal from an - uninformative baseline (see [How calibration works](#how-calibration-works)). Use this to decide - whether a metric is trustworthy in the first place. -- **`de`** — export per-gene differential expression (statistic + adjusted p) to HDF5, since DE - is tightly coupled with several protocols. - -Our accompanying publiciation: TODO_LINK_HERE +**→ Full documentation at ** ## Install ```bash -pip install -e . # provides the `scperteval` command -``` - -## Input data - -scPertEval reads one preprocessed AnnData (`.h5ad`) per dataset. Only three things are required: - -- **`adata.X`** — normalized expression, cells × genes (e.g. `sc.pp.normalize_total` + `sc.pp.log1p`); sparse or dense float. -- **`adata.obs["perturbation"]`** — the perturbation label for each cell; control cells use the label `"control"`. Both names are configurable (`--perturbation-key` / `--control-label`). -- **`adata.var_names`** — gene identifiers, used as the DEG labels. - -Perturbations with at least `--min-cells` cells (default 30) are evaluated. Nothing else is -needed — references, DE, and PCA are all recomputed in memory, so no `uns`/`obsm`/`layers` are read. - -**Sample datasets.** Seven preprocessed perturbation datasets live in a public, read-only GCS -bucket and serve as a template for the format above: - -```bash -gsutil ls gs://scperteval/processed/ # wessels23, replogle22{k562,rpe1}, nadig25{hepg2,jurkat}, arch1, kaden25rpe1 -gsutil cp gs://scperteval/processed/wessels23_processed_complete.h5ad . -``` - -No gcloud account is needed — each file is also reachable over plain HTTPS at -`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. - -## Run it - -```bash -# protocols by name — including parameterised ones (set k / padj per protocol) -scperteval calibrate data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test - -# a parameterised protocol with no value uses its default (k=50, padj=0.05) -scperteval calibrate data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU - -# a whole group, or everything (parameterised protocols use their defaults) -scperteval calibrate data/wessels23.h5ad -p distributional --de-method MWU -scperteval calibrate data/wessels23.h5ad -p all --de-method t-test - -# DRF calibration only (compute DRF only; exclude BDS) -scperteval calibrate data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf - -# SCORE predictions against ground truth — predicted cells vs real cells, per protocol. -# predictions.h5ad must have the same genes and perturbation labels as the dataset. -scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc --de-method t-test - -# DE only — writes per-gene statistic + adjusted p to HDF5 (no protocol calibration) -# Provided as a convenience, since DE methods are tightly coupled with some evaluation protocols -scperteval de data/wessels23.h5ad --methods MWU - -# discover what's available -scperteval list protocols # also: de-methods | spaces | sources | calibrators -``` - -Each command prints a summary table and writes a per-perturbation CSV named -`____.csv`: `calibrate` writes the raw control values and the -calibrated DRF/BDS per perturbation (`…__drf.csv` / `…__bds.csv`); `score` writes the raw metric -value per perturbation (`…__score.csv`). `--profile` adds a per-protocol wall-clock timing CSV. - -**DE backends** (`scperteval list de-methods`): `t-test` (default, Welch's, moment-based), -`MWU` (Cliff's δ via illico), and `t-test_overestim_var` (scanpy's conservative-variance -variant — the reference variance is scaled by the target's cell count). Select one with -`--de-method` for a `calibrate`/`score`, or list several with `--methods` for a `de` export. The overestim -variant is a selectable backend for new protocols; no current protocol uses it. - -
scperteval calibrate --help - -``` -usage: scperteval calibrate [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] - [--subsample SUBSAMPLE] [--seed SEED] [--positive POSITIVE] - [--negative NEGATIVE] [--output {drf,bds}] [--out-dir OUT_DIR] - [--workers WORKERS] [--perturbation-key PERTURBATION_KEY] - [--control-label CONTROL_LABEL] [--min-cells MIN_CELLS] - [--profile] [--quiet] - dataset - - -p, --protocols comma-separated names (parameterised as name=value, e.g. - mse_top_k=30), a group (pseudobulk|distributional|de), or 'all' - --de-method {MWU, t-test, t-test_overestim_var} DE backend for every DE unit: - the interpolated positive control, the top_k/degs spaces, - the de_* protocols, and the WMSE weights - --subsample cells in the single-cell reference sample (default 8192) - --output {drf, bds} how per-perturbation values are calibrated - --positive/--negative override a protocol's controls by source name - --min-cells skip perturbations with fewer cells - --profile also write a per-protocol wall-clock timing table -``` -
- -
scperteval score --help - +pip install scperteval ``` -usage: scperteval score [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] - [--subsample SUBSAMPLE] [--seed SEED] [--out-dir OUT_DIR] [--workers WORKERS] - [--perturbation-key PERTURBATION_KEY] [--control-label CONTROL_LABEL] - [--min-cells MIN_CELLS] [--profile] [--quiet] - dataset predictions - dataset preprocessed .h5ad — the ground truth (real cells) - predictions predicted .h5ad — same genes and perturbation labels as the dataset - -p, --protocols comma-separated names, a group, or 'all' - --de-method DE backend for the de_* protocols, the top_k/degs spaces, and WMSE weights - --subsample cells in the all-perturbed reference (the ground truth is never subsampled) -``` - -Unlike `calibrate`, there are no `--positive`/`--negative`/`--output` options: the candidate is -always your prediction and the output is always the raw `score`. -
- -## Use it from Python - -Install with `pip install scperteval` (or, from this repo, -`pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git"`). -The simplest path mirrors the CLI — call it via subprocess, exactly as the figure notebook does: - -```python -import subprocess, sys - -subprocess.run([sys.executable, "-m", "scperteval", "calibrate", "data/wessels23.h5ad", - "-p", "all", "--de-method", "t-test", "--out-dir", "results"], check=True) -# -> results/wessels23____drf.csv (raw control values + calibrated DRF per perturbation) - -# score predictions against ground truth instead: -subprocess.run([sys.executable, "-m", "scperteval", "score", "data/wessels23.h5ad", - "predictions.h5ad", "-p", "all", "--out-dir", "results"], check=True) -# -> results/wessels23____score.csv (raw metric value per perturbation) -``` - -## Look up an Evaluation Protocol - -Two files define each protocol: - -- **[`scperteval/protocols/metrics.py`](scperteval/protocols/metrics.py)** — the metric, as a - pure function of the ground truth and a `prediction` (the candidate being scored — a positive - or negative control under `calibrate`, or your model's output under `score`). e.g. `mse`, `mmd`, - `de_auprc`: - ```python - def mse(gt, prediction, ctx): - return float(np.mean((gt - prediction) ** 2)) - ``` -- **[`scperteval/protocols/table.py`](scperteval/protocols/table.py)** — one row wiring that function - to its data: the data representation it receives (`representation`), feature space, - reference centering, positive/negative controls, which direction is `better` - (`"higher"`/`"lower"`), and the `perfect` score: - ```python - Protocol("mse", M.mse, representation="centroid", - positive="interpolated", negative="all_perturbed_mean", better="lower", perfect=0.0) - ``` - -The next section breaks these arguments down while building one up from scratch. - -## Create a protocol - -A protocol is two things: a pure metric **function** and a one-line **spec** that wires it -to data and scoring. We'll ease in — the simplest possible protocol first, then the spec -broken down, then a few richer examples. - -### Start simple - -Here is a complete new protocol: mean absolute error on the standard pseudobulk profiles. - -1. Add a pure function to [`scperteval/protocols/metrics.py`](scperteval/protocols/metrics.py): - ```python - def mae(gt, prediction, ctx): - return float(np.mean(np.abs(gt - prediction))) - ``` - Every metric function has this signature. `gt` is one perturbation's ground-truth - profile; `prediction` is the candidate being compared against it (under `calibrate`, scPertEval - calls the function once for the positive control and once for the negative; under `score`, once - with your model's prediction). `ctx` is the dataset context, needed by only a few metrics — - ignore it otherwise. Return a single number. - -2. Add a row to [`scperteval/protocols/table.py`](scperteval/protocols/table.py): - ```python - Protocol("mae", M.mae, representation="centroid", - positive="interpolated", negative="all_perturbed_mean", - better="lower", perfect=0.0) - ``` - -Run it with `scperteval calibrate data.h5ad -p mae`. That is the whole protocol: MAE between each -perturbation's pseudobulk profile and its positive and negative controls, scored as -lower-is-better toward a perfect of 0. - -### The spec - -That row is the spec; parameters include: - -| argument | meaning | -|---|---| -| `name` | selects the protocol on the CLI (`-p mae`) | -| `representation` | the shape of each datapoint your function receives (see below) | -| `scope` | `"perturbation"` (default) or `"dataset"` — how many perturbations at once (see below) | -| `space` | which features to score — `full` (default), or a feature space like `top_50` | -| `centering` | a baseline subtracted before scoring, e.g. `"ctrl"` (default: none) | -| `positive` / `negative` | the two control sources to compare | -| `better` | `"higher"` or `"lower"` — which direction is an improvement | -| `perfect` | the value a flawless prediction attains | -| `param` | optional — a parameter family (`top_k`, `pca_k`, `degs_padj`, `overlap_k`) that makes the protocol tunable from the CLI; omit for a fixed protocol | - -**`representation`** decides the *shape* of each datapoint — the format `gt` and -`prediction` arrive in — so you never deal with sampling, references, or projection yourself: - -| `representation` | a datapoint is | -|---|---| -| `centroid` | a 1-D pseudobulk vector (one value per gene) | -| `population` | a `(cells × genes)` matrix | -| `de` | a `DEResult` (for the ground truth) / per-gene `|score|` ranking (for a prediction) | - -**`scope`** is the independent companion axis — *how many* perturbations the metric sees at once: - -| `scope` | the metric is called | -|---|---| -| `perturbation` (default) | once per perturbation — gets that perturbation's `(gt, prediction)` datapoints and returns a scalar | -| `dataset` | once for the whole dataset — gets the **list** of every perturbation's `gt` and `prediction` datapoints and returns one score per perturbation (e.g. a retrieval `rank`) | - -The two compose freely: `rank` is just `representation="centroid", scope="dataset"`; a -distributional retrieval metric would be `representation="population", scope="dataset"`. - -Many rows repeat the same wiring, so the top of `table.py` predefines the common -combinations as plain dicts. You then unpack one into a row with `**` (Python's -keyword-expansion syntax) to avoid retyping it: -```python -_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") -_LOWER = dict(better="lower", perfect=0.0) -``` -With those, the `mae` row above is exactly `Protocol("mae", M.mae, -representation="centroid", **_PB, **_LOWER)` — same protocol, less repetition. You'll see -these bundles reused throughout the table. - -### Building blocks — the palette - -The values those arguments take — feature spaces, control sources, DE methods, calibrators -— are registered building blocks. `scperteval list ` shows what's available -in each, with descriptions: - -**Feature spaces** (the `space` argument) +Or from this repo: ```bash -$ scperteval list spaces -degs_0.05 — ground-truth DEGs at adjusted p < 0.05, per perturbation -full — all genes, no transform -pca_50 — top 50 principal components (fit on the dataset) -top_50 — top 50 genes by ground-truth effect size, per perturbation +pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git" ``` -`top_` / `pca_` / `degs_` are parameterised families (the defaults are shown); -a protocol template picks the value. If the space you need isn't here, see -[Add a feature space](#add-a-feature-space). - -**DE methods** (the `--de-method` choice) +## Quick start ```bash -$ scperteval list de-methods -MWU — Mann-Whitney U / Cliff's delta effect size (via illico) -t-test — Welch's t-test (default) — moment-based and fast +scperteval run data/wessels23.h5ad -p all --de-method t-test +scperteval list protocols # also: de-methods | spaces | sources | calibrators ``` -Chosen with `--de-method`; it applies to **every** DE-dependent unit (the `interpolated` -positive control, the `top_k`/`degs` spaces, the `de_*` protocols, and the WMSE weights). -To add another, see [Add a DE method](#add-a-de-method). - -**Control sources** (the `positive` / `negative` arguments) - -```bash -$ scperteval list sources -all_perturbed (cells) — all-perturbed reference sample, leave-one-out (single-cell negative control) -all_perturbed_mean (centroid) — all-perturbed mean, excluding the target — leave-one-out (pseudobulk sibling of all_perturbed; pseudobulk negative control) -control (cells) — non-targeting control cells -global_mean (centroid) — mean of all perturbations — shared baseline for the ranking protocols -gt_all_cells (cells) — ground truth — all of a perturbation's real cells (prediction-scoring truth) -gt_half (cells) — ground truth — the first half of a perturbation's cells (calibration truth) -interpolated (centroid) — interpolated duplicate — DE-weighted blend of the held-out half and the dataset mean (pseudobulk positive control) -prediction (cells) — model-predicted cells for the perturbation, from the --predictions h5ad -tech_dup (cells) — technical duplicate — the held-out second half (single-cell positive control) -``` - -The truth source is chosen by the command, not by a protocol: `calibrate` uses `gt_half` and -holds the other half out to build the positive control; `score` uses `gt_all_cells` and compares -it to `prediction`. - -Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative=` (or -`--positive`/`--negative`). To add another, see [Add a control source](#add-a-control-source). - -**Calibrators** (the `--output` choice) - -```bash -$ scperteval list calibrators -drf — Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025) -bds — Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026) -score — raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode) -``` - -`drf`/`bds` are chosen with `calibrate --output`; `score` is selected automatically by the -`score` command. To add another, see [Add a calibrator](#add-a-calibrator). - -### More examples - -With the spec and the palette in hand, richer protocols are just different combinations. - -**Same wiring, different metric.** Cosine distance on pseudobulk reuses the bundles wholesale: -```python -def cosine(gt, prediction, ctx): - return 1.0 - float(gt @ prediction / (np.linalg.norm(gt) * np.linalg.norm(prediction))) -``` -```python -Protocol("cosine", M.cosine, representation="centroid", **_PB, **_LOWER) -``` - -**Restrict to a feature space.** Set `space` to score only some genes — e.g. MAE on the -top-50 DEGs: -```python -Protocol("mae_top50", M.mae, representation="centroid", space="top_50", **_PB, **_LOWER) -``` - -**Expose the space as a knob (parameterised).** To make `k` adjustable per invocation, add a -`param` to the same `Protocol(...)` row — nothing else changes. The row's name carries the -parameter, and the value is supplied on the CLI: -```python -Protocol("mae_top_k", M.mae, representation="centroid", param=top_k, **_PB, **_LOWER) -``` -Then `scperteval calibrate data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The -families are `top_k` (top-k DEGs), `pca_k` (k PCs), and `degs_padj` (DEGs at adjusted -p < padj) for the space, and `overlap_k` to feed an integer straight to the metric. - -**A metric over cells, not profiles.** Switch `representation` to `population` and your -function receives `(cells × genes)` matrices; pair it with the single-cell controls: -```python -def my_mmd(gt, prediction, ctx): # gt, prediction are (cells × genes) - ... -``` -```python -Protocol("my_mmd_top50", M.my_mmd, representation="population", space="top_50", - positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) -``` -This changes two pieces at once — the `representation` (so the function sees cells) and the controls -(the single-cell positive/negative) — which is the general pattern for a distributional -protocol. - -By now you've seen every moving part: the function, the spec, the building blocks the spec -draws on, fixed and parameterised spaces, and switching the representation the function -sees. Most new metrics are some combination of these. - -## Add a building block - -Spaces, DE methods, control sources, and calibrators are registered units — add one when -the palette is missing what a new protocol needs. Each is a small function (or object) plus -a one-line registration. - -### Add a feature space - -A space is a function `(X, ctx, pert) -> dense (cells × genes) array` that transforms the -gene axis. Register it with `@SPACES.register` in -[`scperteval/blocks/spaces.py`](scperteval/blocks/spaces.py); pass `global_space=True` if it doesn't -depend on the perturbation (so it can be computed once and shared): - -```python -@SPACES.register("hvg_100", global_space=True, description="100 highest-variance genes") -def space_hvg(X, ctx, pert): - keep = ... # indices of the genes to keep - return to_dense(X[:, keep]) -``` - -For a per-perturbation subset derived from the ground-truth DE (like `top_k` / `degs`), use -the `register_de_space(name, field=..., top=...)` helper in the same file instead. - -### Add a DE method - -A DE method maps `(target_cells, reference_cells) -> DEResult(score, pvalue, pvalue_adj)`. -Register it with `@DE_METHODS.register` in [`scperteval/blocks/de.py`](scperteval/blocks/de.py) (the -`bh` helper there BH-adjusts p-values): - -```python -@DE_METHODS.register("my_test", description="…") -def de_my_test(target, reference): - score, pvalue = ... # per-gene statistic and raw p-value - return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) -``` - -Then `--de-method my_test` routes every DE-dependent unit through it. - -### Add a control source - -A source maps `(ctx, pert) -> cells or a 1-D centroid`, declaring which with `provides`. -Register it with `@SOURCES.register` in [`scperteval/sources.py`](scperteval/sources.py): - -```python -@SOURCES.register("my_baseline", provides="centroid", description="…") -def src_my_baseline(ctx, pert): - return ... # a 1-D centroid (or cells, if provides="cells") -``` - -Use it as a control via `positive=`/`negative=` in a row, or `--positive`/`--negative` at -the CLI. - -### Add a calibrator - -A calibrator declares the control roles it needs, a per-perturbation combine, and a -cross-perturbation aggregate. Add a `Calibrator` to the `CALIBRATORS` dict in -[`scperteval/calibrators.py`](scperteval/calibrators.py): - -```python -CALIBRATORS["my_score"] = Calibrator( - "my_score", ("positive", "negative"), - per_pert=lambda raws, p: ..., # raws["positive"], raws["negative"] -> one number - aggregate=lambda v: {"my_score": float(np.nanmean(v))}, - description="…", -) -``` - -Then `--output my_score` reports it. - -## Scoring predictions against ground truth - -`scperteval score dataset.h5ad predictions.h5ad` is the conventional evaluation: each protocol's -metric is applied to your **predicted** cells against the **real** cells, one score per -perturbation. It runs the *same* protocol catalog as `calibrate`; only two pieces differ. - -- **ground truth** — *all* of a perturbation's real cells (the `gt_all_cells` source). Unlike - calibration, no half is held out and no positive/negative controls are built — the ground - truth is the whole real population. -- **prediction** — the matching cells from your `predictions.h5ad` (the `prediction` source). - The prediction file must contain the dataset's exact gene set (any order — columns are - reordered by name so the comparison lines up gene-for-gene) and the same perturbation labels. - A gene-set mismatch, or a perturbation present in the dataset but absent from the predictions, - raises an error naming exactly what's wrong. - -The `score` calibrator reports each protocol's raw metric value per perturbation and its -mean/median across perturbations, written to `____score.csv`. Higher- vs -lower-is-better follows each protocol's `better` field, exactly as in calibration. - -Architecturally this reuses everything — the per-perturbation loop, every metric, representation, -and feature space are shared with `calibrate`. The only differences are the **truth source** -(`gt_all_cells` instead of the held-out `gt_half`) and the **calibrator** (`score`, which needs -only the prediction, instead of `drf`/`bds`, which need both controls). The DE-derived feature -spaces (`top_k`, `degs`) and the WMSE weights are computed from this same all-cells ground truth. - -## How calibration works - -scPertEval's claim — a usable catalog of protocols — rests on **calibrating** each protocol -against two empirical controls per perturbation, so you can see whether a metric actually -separates signal from baseline rather than read a raw, uninterpretable number. - -- **positive control** — the best realistic candidate: the **technical duplicate** (a - held-out replicate) for single-cell protocols, the **interpolated duplicate** for pseudobulk. -- **negative control** — an uninformative baseline: the **all-perturbed reference, - excluding the target perturbation** (a full-resolution mean for pseudobulk; an 8192-cell - subsample for single-cell distances). - -**Dynamic Range Fraction (DRF)** — where the protocol's value sits between the negative -control (floor) and the perfect score, anchored by the positive control: - -``` -DRF = (positive − negative) / (perfect − negative) # per perturbation, clipped to [-1, 1] -``` - -`--output drf` reports the mean/median across perturbations. High DRF means the protocol -discriminates real signal; near zero means it doesn't. Introduced by Miller et al., -*Deep Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines on -Well-Calibrated Metrics* (2025) — . - -**Bound Discrimination Score (BDS)** — the fraction of perturbations for which the positive -control beats the negative control under this protocol: - -``` -BDS = fraction of perturbations where positive control beats negative control # in [0, 1] -``` - -`--output bds` reports this fraction. It's a sensitivity check: a protocol with low BDS -can't even tell a technical replicate from an uninformative baseline, so its scores -shouldn't be trusted. Introduced by Vollenweider & Bühlmann, *Signal, Bounds, and -Baselines* (SBB, 2026) — (code: -). +Sample datasets are available at +`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. --- diff --git a/docs/api.md b/docs/api.md index 71467b5..0b0dabc 100644 --- a/docs/api.md +++ b/docs/api.md @@ -28,17 +28,9 @@ ## Protocols -```{eval-rst} -.. module:: scperteval.protocols -.. currentmodule:: scperteval.protocols - -.. autosummary:: - :toctree: generated - - PROTOCOLS - GROUPS - TABLE -``` +`scperteval.protocols.TABLE` — list of all `Protocol` objects. +`scperteval.protocols.PROTOCOLS` — `{name: Protocol}` dict. +`scperteval.protocols.GROUPS` — sorted list of group names. ### Metrics diff --git a/docs/conf.py b/docs/conf.py index cdb841c..7469883 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -110,6 +110,7 @@ "use_repository_button": True, "path_to_docs": "docs/", "navigation_with_keys": False, + "show_navbar_depth": 1, } pygments_style = "default" diff --git a/docs/contributing.md b/docs/contributing.md index 3771dea..69b9914 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,56 +1,6 @@ -# Contributing guide +```{include} ../CONTRIBUTORS.md -We welcome contributions! Please open an issue or pull request on [GitHub](https://github.com/Virtual-Cell-Research-Community/scPertEval). - -## Installing dev dependencies - -:::::{tab-set} -::::{tab-item} uv - -```bash -uv sync --group dev ``` -:::: - -::::{tab-item} pip - -```bash -pip install -e ".[dev]" -``` - -:::: -::::: - -## Code style - -This project uses [ruff][] for formatting and linting, and [mypy][]/[pyright][] for type checking. - -```bash -uv run ruff format . -uv run ruff check . -uv run mypy scperteval -``` - -[ruff]: https://docs.astral.sh/ruff/ - -## Building the docs locally - -```bash -uv sync --group doc -uv run sphinx-build -M html docs docs/_build -W -``` - -Then open `docs/_build/html/index.html`. - -## Publishing a release - -Update the version in `pyproject.toml`, commit, push, and create a GitHub release tagged `vX.Y.Z`. - -## Writing documentation - -- Use [numpy-style docstrings][numpydoc]. -- Add tutorials as Jupyter notebooks in `docs/notebooks/`. -- Add intersphinx entries to `docs/conf.py` for cross-references to external packages. - -[numpydoc]: https://numpydoc.readthedocs.io/en/latest/format.html +For development setup (installing dependencies, running linters, building docs locally), +see [Installation](installation.md#development-setup). diff --git a/docs/index.md b/docs/index.md index 53ff8f2..2ab513a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,13 +1,89 @@ -```{include} ../README.md +# scPertEval — Evaluation Protocols for Perturbation Sequencing +scPertEval is a command-line tool for **experimenting with and sharing reference implementations of +evaluation protocols** in single-cell perturbation studies. + +Evaluating predictions across a dataset's perturbations reduces to a single question: how +different is one group of cells from another? To answer this, an **evaluation protocol** is +defined: a specific formulation of a metric, along with some representation of the +perturbation data fed to the metric. However, there are a multitude of possibilities — many +already reflected in the literature — and it can be challenging to compare and contrast +protocols across the field and ultimately choose the right approach for a given dataset and +problem space. + +scPertEval renders each protocol as a short, readable building block to run, read, reuse, +and contribute back — a place for collaboration and alignment in the field. Run the tool by +specifying a dataset, one or more protocols, and a method of differential expression; the +tool outputs calibration data: the **Dynamic Range Fraction (DRF)** and the **Bound +Discrimination Score (BDS)** — quantifying how well the protocol separates real perturbation +signal from an uninformative baseline (see [How scoring works](user-guide/scoring.md)). + +## Quick start + +```bash +pip install scperteval +scperteval run data/wessels23.h5ad -p all --de-method t-test +``` + +::::{grid} 1 2 3 3 +:gutter: 2 + +:::{grid-item-card} {octicon}`desktop-download;1em;` Installation +:link: installation +:link-type: doc +Get scPertEval installed and set up your development environment. +::: + +:::{grid-item-card} {octicon}`book;1em;` User guide +:link: user-guide/index +:link-type: doc +Learn how to run protocols, interpret scores, and explore the building blocks. +::: + +:::{grid-item-card} {octicon}`mortar-board;1em;` Tutorials +:link: tutorials +:link-type: doc +Step-by-step notebooks: CLI walkthrough, Python API, and extending the tool. +::: + +:::{grid-item-card} {octicon}`code-square;1em;` API reference +:link: api +:link-type: doc +Full reference for the Python API. +::: + +:::{grid-item-card} {octicon}`mark-github;1em;` GitHub +:link: https://github.com/Virtual-Cell-Research-Community/scPertEval +:link-type: url +Browse the source code, open issues, or contribute a pull request. +::: + +:::: + +## Citation + +If you use scPertEval, please cite {cite}`Schafer_2026`. + +```bibtex +@unpublished{Schafer_2026, + author = {Schäfer, Philipp S. L. and Reid, Kendall A. and Boldyga, Zach + and Aksu, Ekin Deniz and Hakem, Hugo and Saez-Rodriguez, Julio}, + title = {Towards a Principled Evaluation of Single-Cell Perturbation + Response Prediction Models}, + note = {In preparation}, + year = {2026}, +} ``` ```{toctree} :hidden: true -:maxdepth: 1 +:maxdepth: 2 +installation.md +user-guide/index +tutorials.md api.md changelog.md -contributing.md +Contributing references.md ``` diff --git a/docs/installation.md b/docs/installation.md new file mode 100644 index 0000000..6e149af --- /dev/null +++ b/docs/installation.md @@ -0,0 +1,44 @@ +# Installation + +## From PyPI + +```bash +pip install scperteval +``` + +## From source + +```bash +pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git" +``` + +Or, for an editable install from a local clone: + +```bash +git clone https://github.com/Virtual-Cell-Research-Community/scPertEval.git +cd scPertEval +pip install -e . +``` + +## Development setup + +Install all dev dependencies (linting + docs + tests): + +```bash +uv sync --group dev +``` + +Run linters: + +```bash +uv run ruff format . +uv run ruff check . +uv run mypy src/scperteval +``` + +Build the docs locally with live reload: + +```bash +uv sync --group docs +uv run sphinx-autobuild docs docs/_build/html +``` diff --git a/docs/references.bib b/docs/references.bib index 0ce3036..fbcd467 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -1,10 +1,26 @@ -@article{Virshup_2023, - doi = {10.1038/s41587-023-01733-8}, - url = {https://doi.org/10.1038%2Fs41587-023-01733-8}, - year = 2023, - month = {apr}, - publisher = {Springer Science and Business Media {LLC}}, - author = {Isaac Virshup and Danila Bredikhin and Lukas Heumos and Giovanni Palla and Gregor Sturm and Adam Gayoso and Ilia Kats and Mikaela Koutrouli and Philipp Angerer and Volker Bergen and Pierre Boyeau and Maren Büttner and Gokcen Eraslan and David Fischer and Max Frank and Justin Hong and Michal Klein and Marius Lange and Romain Lopez and Mohammad Lotfollahi and Malte D. Luecken and Fidel Ramirez and Jeffrey Regier and Sergei Rybakov and Anna C. Schaar and Valeh Valiollah Pour Amiri and Philipp Weiler and Galen Xing and Bonnie Berger and Dana Pe'er and Aviv Regev and Sarah A. Teichmann and Francesca Finotello and F. Alexander Wolf and Nir Yosef and Oliver Stegle and Fabian J. Theis}, - title = {The scverse project provides a computational ecosystem for single-cell omics data analysis}, - journal = {Nature Biotechnology} +@unpublished{Schafer_2026, + author = {Schäfer, Philipp S. L. and Reid, Kendall A. and Boldyga, Zach and Aksu, Ekin Deniz and Hakem, Hugo and Saez-Rodriguez, Julio}, + title = {Towards a Principled Evaluation of Single-Cell Perturbation Response Prediction Models}, + note = {In preparation}, + year = {2026}, +} + +@misc{Miller_2025, + title = {Deep {{Learning-Based Genetic Perturbation Models Do Outperform Uninformative Baselines}} on {{Well-Calibrated Metrics}}}, + author = {Miller, Henry E. and Mejia, Gabriel M. and Leblanc, Francis J. A. and Wang, Bo and Swain, Brendan and Camillo, Lucas Paulo de Lima}, + year = {2025}, + month = oct, + publisher = {bioRxiv}, + pages = {2025.10.20.683304}, + doi = {10.1101/2025.10.20.683304}, +} + +@misc{Vollenweider_2026, + title = {Signal, {{Bounds}}, and {{Baselines}}: {{Principles}} for {{Rigorous Evaluation}} of {{High-Dimensional Biological Perturbation Prediction}}}, + author = {Vollenweider, Michael and B{\"u}hlmann, Peter}, + year = {2026}, + month = apr, + publisher = {bioRxiv}, + pages = {2026.04.20.719650}, + doi = {10.64898/2026.04.20.719650}, } diff --git a/docs/tutorials.md b/docs/tutorials.md new file mode 100644 index 0000000..9cd4fb2 --- /dev/null +++ b/docs/tutorials.md @@ -0,0 +1,7 @@ +# Tutorials + +Notebooks are coming soon. Planned tutorials: + +- **CLI walkthrough** — run protocols on a dataset end-to-end from the command line +- **Python API** — use scPertEval programmatically from a notebook or script +- **Extending scPertEval** — add a new protocol, feature space, or control source diff --git a/docs/user-guide/building-blocks.md b/docs/user-guide/building-blocks.md new file mode 100644 index 0000000..ffd2ede --- /dev/null +++ b/docs/user-guide/building-blocks.md @@ -0,0 +1,68 @@ +# Building blocks + +Spaces, DE methods, control sources, and calibrators are registered units — add one when +the palette is missing what a new protocol needs. Each is a small function (or object) plus +a one-line registration. + +## Add a feature space + +A space is a function `(X, ctx, pert) -> dense (cells × genes) array` that transforms the +gene axis. Register it with `@SPACES.register` in +[`src/scperteval/blocks/spaces.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/blocks/spaces.py); pass `global_space=True` if it doesn't +depend on the perturbation (so it can be computed once and shared): + +```python +@SPACES.register("hvg_100", global_space=True, description="100 highest-variance genes") +def space_hvg(X, ctx, pert): + keep = ... # indices of the genes to keep + return to_dense(X[:, keep]) +``` + +For a per-perturbation subset derived from the ground-truth DE (like `top_k` / `degs`), use +the `register_de_space(name, field=..., top=...)` helper in the same file instead. + +## Add a DE method + +A DE method maps `(target_cells, reference_cells) -> DEResult(score, pvalue, pvalue_adj)`. +Register it with `@DE_METHODS.register` in [`src/scperteval/blocks/de.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/blocks/de.py) (the +`bh` helper there BH-adjusts p-values): + +```python +@DE_METHODS.register("my_test", description="…") +def de_my_test(target, reference): + score, pvalue = ... # per-gene statistic and raw p-value + return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) +``` + +Then `--de-method my_test` routes every DE-dependent unit through it. + +## Add a control source + +A source maps `(ctx, pert) -> cells or a 1-D centroid`, declaring which with `provides`. +Register it with `@SOURCES.register` in [`src/scperteval/sources.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/sources.py): + +```python +@SOURCES.register("my_baseline", provides="centroid", description="…") +def src_my_baseline(ctx, pert): + return ... # a 1-D centroid (or cells, if provides="cells") +``` + +Use it as a control via `positive=`/`negative=` in a row, or `--positive`/`--negative` at +the CLI. + +## Add a calibrator + +A calibrator declares the control roles it needs, a per-perturbation combine, and a +cross-perturbation aggregate. Add a `Calibrator` to the `CALIBRATORS` dict in +[`src/scperteval/calibrators.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/calibrators.py): + +```python +CALIBRATORS["my_score"] = Calibrator( + "my_score", ("positive", "negative"), + per_pert=lambda raws, p: ..., # raws["positive"], raws["negative"] -> one number + aggregate=lambda v: {"my_score": float(np.nanmean(v))}, + description="…", +) +``` + +Then `--output my_score` reports it. diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md new file mode 100644 index 0000000..a5a568f --- /dev/null +++ b/docs/user-guide/index.md @@ -0,0 +1,10 @@ +# User guide + +```{toctree} +:maxdepth: 1 + +usage +scoring +protocols +building-blocks +``` diff --git a/docs/user-guide/protocols.md b/docs/user-guide/protocols.md new file mode 100644 index 0000000..5aebf80 --- /dev/null +++ b/docs/user-guide/protocols.md @@ -0,0 +1,221 @@ +# Protocols + +## Look up an evaluation protocol + +Two files define each protocol: + +- **[`src/scperteval/protocols/metrics.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/metrics.py)** — the metric, as a + pure function of the ground truth and a `prediction` (whichever control is being scored, + positive or negative). e.g. `mse`, `mmd`, `de_auprc`: + + ```python + def mse(gt, prediction, ctx): + return float(np.mean((gt - prediction) ** 2)) + ``` + +- **[`src/scperteval/protocols/table.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/table.py)** — one row wiring that function + to its data: the data representation it receives (`representation`), feature space, + reference centering, positive/negative controls, which direction is `better` + (`"higher"`/`"lower"`), and the `perfect` score: + + ```python + Protocol("mse", M.mse, representation="centroid", + positive="interpolated", negative="all_perturbed_mean", better="lower", perfect=0.0) + ``` + +The next section breaks these arguments down while building one up from scratch. + +## Create a protocol + +A protocol is two things: a pure metric **function** and a one-line **spec** that wires it +to data and scoring. We'll ease in — the simplest possible protocol first, then the spec +broken down, then a few richer examples. + +### Start simple + +Here is a complete new protocol: mean absolute error on the standard pseudobulk profiles. + +1. Add a pure function to [`src/scperteval/protocols/metrics.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/metrics.py): + + ```python + def mae(gt, prediction, ctx): + return float(np.mean(np.abs(gt - prediction))) + ``` + + Every metric function has this signature. `gt` is one perturbation's ground-truth + profile; `prediction` is a control being compared against it (scPertEval calls the function + once for the positive control and once for the negative). `ctx` is the dataset context, + needed by only a few metrics — ignore it otherwise. Return a single number. + +2. Add a row to [`src/scperteval/protocols/table.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/table.py): + + ```python + Protocol("mae", M.mae, representation="centroid", + positive="interpolated", negative="all_perturbed_mean", + better="lower", perfect=0.0) + ``` + +Run it with `scperteval run data.h5ad -p mae`. That is the whole protocol: MAE between each +perturbation's pseudobulk profile and its positive and negative controls, scored as +lower-is-better toward a perfect of 0. + +### The spec + +That row is the spec; parameters include: + +| argument | meaning | +|---|---| +| `name` | selects the protocol on the CLI (`-p mae`) | +| `representation` | the shape of each datapoint your function receives (see below) | +| `scope` | `"perturbation"` (default) or `"dataset"` — how many perturbations at once (see below) | +| `space` | which features to score — `full` (default), or a feature space like `top_50` | +| `centering` | a baseline subtracted before scoring, e.g. `"ctrl"` (default: none) | +| `positive` / `negative` | the two control sources to compare | +| `better` | `"higher"` or `"lower"` — which direction is an improvement | +| `perfect` | the value a flawless prediction attains | +| `param` | optional — a parameter family (`top_k`, `pca_k`, `degs_padj`, `overlap_k`) that makes the protocol tunable from the CLI; omit for a fixed protocol | + +**`representation`** decides the *shape* of each datapoint — the format `gt` and +`prediction` arrive in — so you never deal with sampling, references, or projection yourself: + +| `representation` | a datapoint is | +|---|---| +| `centroid` | a 1-D pseudobulk vector (one value per gene) | +| `population` | a `(cells × genes)` matrix | +| `de` | a `DEResult` (for `gt`) / per-gene `\|score\|` ranking (for a prediction) | + +**`scope`** is the independent companion axis — *how many* perturbations the metric sees at once: + +| `scope` | the metric is called | +|---|---| +| `perturbation` (default) | once per perturbation — gets that perturbation's `(gt, prediction)` datapoints and returns a scalar | +| `dataset` | once for the whole dataset — gets the **list** of every perturbation's `gt` and `prediction` datapoints and returns one score per perturbation (e.g. a retrieval `rank`) | + +The two compose freely: `rank` is just `representation="centroid", scope="dataset"`; a +distributional retrieval metric would be `representation="population", scope="dataset"`. + +Many rows repeat the same wiring, so the top of `table.py` predefines the common +combinations as plain dicts. You then unpack one into a row with `**` (Python's +keyword-expansion syntax) to avoid retyping it: + +```python +_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") +_LOWER = dict(better="lower", perfect=0.0) +``` + +With those, the `mae` row above is exactly `Protocol("mae", M.mae, +representation="centroid", **_PB, **_LOWER)` — same protocol, less repetition. You'll see +these bundles reused throughout the table. + +### Building blocks — the palette + +The values those arguments take — feature spaces, control sources, DE methods, calibrators +— are registered building blocks. `scperteval list ` shows what's available +in each, with descriptions: + +**Feature spaces** (the `space` argument) + +```bash +$ scperteval list spaces +degs_0.05 — ground-truth DEGs at adjusted p < 0.05, per perturbation +full — all genes, no transform +pca_50 — top 50 principal components (fit on the dataset) +top_50 — top 50 genes by ground-truth effect size, per perturbation +``` + +`top_` / `pca_` / `degs_` are parameterised families (the defaults are shown); +a protocol template picks the value. If the space you need isn't here, see +[Add a feature space](building-blocks.md#add-a-feature-space). + +**DE methods** (the `--de-method` choice) + +```bash +$ scperteval list de-methods +MWU — Mann-Whitney U / Cliff's delta effect size (via illico) +t-test — Welch's t-test (default) — moment-based and fast +``` + +Chosen with `--de-method`; it applies to **every** DE-dependent unit (the `interpolated` +positive control, the `top_k`/`degs` spaces, the `de_*` protocols, and the WMSE weights). +To add another, see [Add a DE method](building-blocks.md#add-a-de-method). + +**Control sources** (the `positive` / `negative` arguments) + +```text +$ scperteval list sources +all_perturbed (cells) — all-perturbed reference sample, leave-one-out (single-cell negative control) +all_perturbed_mean (centroid) — all-perturbed mean, excluding the target — leave-one-out (pseudobulk sibling of all_perturbed; pseudobulk negative control) +control (cells) — non-targeting control cells +global_mean (centroid) — mean of all perturbations — shared baseline for the ranking protocols +gt (cells) — ground truth — the first half of a perturbation's cells +interpolated (centroid) — interpolated duplicate — DE-weighted blend of the held-out half and the dataset mean (pseudobulk positive control) +tech_dup (cells) — technical duplicate — the held-out second half (single-cell positive control) +``` + +Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative=` (or +`--positive`/`--negative`). To add another, see [Add a control source](building-blocks.md#add-a-control-source). + +**Calibrators** (the `--output` choice) + +```bash +$ scperteval list calibrators +drf — Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025) +bds — Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026) +``` + +Chosen with `--output`. To add another, see [Add a calibrator](building-blocks.md#add-a-calibrator). + +### More examples + +With the spec and the palette in hand, richer protocols are just different combinations. + +**Same wiring, different metric.** Cosine distance on pseudobulk reuses the bundles wholesale: + +```python +def cosine(gt, prediction, ctx): + return 1.0 - float(gt @ prediction / (np.linalg.norm(gt) * np.linalg.norm(prediction))) +``` + +```python +Protocol("cosine", M.cosine, representation="centroid", **_PB, **_LOWER) +``` + +**Restrict to a feature space.** Set `space` to score only some genes — e.g. MAE on the +top-50 DEGs: + +```python +Protocol("mae_top50", M.mae, representation="centroid", space="top_50", **_PB, **_LOWER) +``` + +**Expose the space as a knob (parameterised).** To make `k` adjustable per run, add a +`param` to the same `Protocol(...)` row — nothing else changes. The row's name carries the +parameter, and the value is supplied on the CLI: + +```python +Protocol("mae_top_k", M.mae, representation="centroid", param=top_k, **_PB, **_LOWER) +``` + +Then `scperteval run data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The +families are `top_k` (top-k DEGs), `pca_k` (k PCs), and `degs_padj` (DEGs at adjusted +p < padj) for the space, and `overlap_k` to feed an integer straight to the metric. + +**A metric over cells, not profiles.** Switch `representation` to `population` and your +function receives `(cells × genes)` matrices; pair it with the single-cell controls: + +```python +def my_mmd(gt, prediction, ctx): # gt, prediction are (cells × genes) + ... +``` + +```python +Protocol("my_mmd_top50", M.my_mmd, representation="population", space="top_50", + positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) +``` + +This changes two pieces at once — the `representation` (so the function sees cells) and the controls +(the single-cell positive/negative) — which is the general pattern for a distributional +protocol. + +By now you've seen every moving part: the function, the spec, the building blocks the spec +draws on, fixed and parameterised spaces, and switching the representation the function +sees. Most new metrics are some combination of these. diff --git a/docs/user-guide/scoring.md b/docs/user-guide/scoring.md new file mode 100644 index 0000000..f23e394 --- /dev/null +++ b/docs/user-guide/scoring.md @@ -0,0 +1,36 @@ +# How scoring works (the calibration) + +scPertEval's claim — a usable catalog of protocols — rests on **calibrating** each protocol +against two empirical controls per perturbation, so you can see whether a metric actually +separates signal from baseline rather than read a raw, uninterpretable number. + +- **positive control** — the best realistic candidate: the **technical duplicate** (a + held-out replicate) for single-cell protocols, the **interpolated duplicate** for pseudobulk. +- **negative control** — an uninformative baseline: the **all-perturbed reference, + excluding the target perturbation** (a full-resolution mean for pseudobulk; an 8192-cell + subsample for single-cell distances). + +## Dynamic Range Fraction (DRF) + +Where the protocol's value sits between the negative control (floor) and the perfect score, +anchored by the positive control: + +```text +DRF = (positive − negative) / (perfect − negative) # per perturbation, clipped to [-1, 1] +``` + +`--output drf` reports the mean/median across perturbations. High DRF means the protocol +discriminates real signal; near zero means it doesn't. Introduced by {cite:t}`Miller_2025`. + +## Bound Discrimination Score (BDS) + +The fraction of perturbations for which the positive control beats the negative control +under this protocol: + +```text +BDS = fraction of perturbations where positive control beats negative control # in [0, 1] +``` + +`--output bds` reports this fraction. It's a sensitivity check: a protocol with low BDS +can't even tell a technical replicate from an uninformative baseline, so its scores +shouldn't be trusted. Introduced by {cite:t}`Vollenweider_2026`. diff --git a/docs/user-guide/usage.md b/docs/user-guide/usage.md new file mode 100644 index 0000000..3134453 --- /dev/null +++ b/docs/user-guide/usage.md @@ -0,0 +1,96 @@ +# Usage + +## Input data + +scPertEval reads one preprocessed AnnData (`.h5ad`) per dataset. Only three things are required: + +- **`adata.X`** — normalized expression, cells × genes (e.g. `sc.pp.normalize_total` + `sc.pp.log1p`); sparse or dense float. +- **`adata.obs["perturbation"]`** — the perturbation label for each cell; control cells use the label `"control"`. Both names are configurable (`--perturbation-key` / `--control-label`). +- **`adata.var_names`** — gene identifiers, used as the DEG labels. + +Perturbations with at least `--min-cells` cells (default 30) are evaluated. Nothing else is +needed — references, DE, and PCA are all recomputed in memory, so no `uns`/`obsm`/`layers` are read. + +**Sample datasets.** Seven preprocessed perturbation datasets live in a public, read-only GCS +bucket and serve as a template for the format above: + +```bash +gsutil ls gs://scperteval/processed/ # wessels23, replogle22{k562,rpe1}, nadig25{hepg2,jurkat}, arch1, kaden25rpe1 +gsutil cp gs://scperteval/processed/wessels23_processed_complete.h5ad . +``` + +No gcloud account is needed — each file is also reachable over plain HTTPS at +`https://storage.googleapis.com/scperteval/processed/_processed_complete.h5ad`. + +## Run it + +```bash +# protocols by name — including parameterised ones (set k / padj per protocol) +scperteval run data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test + +# a parameterised protocol with no value uses its default (k=50, padj=0.05) +scperteval run data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU + +# a whole group, or everything (parameterised protocols use their defaults) +scperteval run data/wessels23.h5ad -p distributional --de-method MWU +scperteval run data/wessels23.h5ad -p all --de-method t-test + +# DRF calibration only (compute DRF only; exclude BDS) +scperteval run data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf + +# DE only — writes per-gene statistic + adjusted p to HDF5 (no protocol calibration) +# Provided as a convenience, since DE methods are tightly coupled with some evaluation protocols +scperteval de data/wessels23.h5ad --methods MWU + +# discover what's available +scperteval list protocols # also: de-methods | spaces | sources | calibrators +``` + +Each run prints a summary table and writes a per-perturbation CSV +`____drf.csv` (the raw control values and the calibrated score for +every perturbation). `--profile` adds a per-protocol wall-clock timing CSV. + +**DE backends** (`scperteval list de-methods`): `t-test` (default, Welch's, moment-based), +`MWU` (Cliff's δ via illico), and `t-test_overestim_var` (scanpy's conservative-variance +variant — the reference variance is scaled by the target's cell count). Select one with +`--de-method` for a `run`, or list several with `--methods` for a `de` export. The overestim +variant is a selectable backend for new protocols; no current protocol uses it. + +
scperteval run --help + +```text +usage: scperteval run [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] + [--subsample SUBSAMPLE] [--seed SEED] [--positive POSITIVE] + [--negative NEGATIVE] [--output {drf,bds}] [--out-dir OUT_DIR] + [--workers WORKERS] [--perturbation-key PERTURBATION_KEY] + [--control-label CONTROL_LABEL] [--min-cells MIN_CELLS] + [--profile] [--quiet] + dataset + + -p, --protocols comma-separated names (parameterised as name=value, e.g. + mse_top_k=30), a group (pseudobulk|distributional|de), or 'all' + --de-method {MWU, t-test, t-test_overestim_var} DE backend for every DE unit: + the interpolated positive control, the top_k/degs spaces, + the de_* protocols, and the WMSE weights + --subsample cells in the single-cell reference sample (default 8192) + --output {drf, bds} how per-perturbation values are calibrated + --positive/--negative override a protocol's controls by source name + --min-cells skip perturbations with fewer cells + --profile also write a per-protocol wall-clock timing table +``` + +
+ +## Use it from Python + +Install with `pip install scperteval` (or, from this repo, +`pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community/scPertEval.git"`). +The simplest path mirrors the CLI — call it via subprocess, exactly as the figure notebook does: + +```python +import subprocess, sys + +subprocess.run([sys.executable, "-m", "scperteval", "run", "data/wessels23.h5ad", + "-p", "all", "--de-method", "t-test", "--out-dir", "results"], check=True) +# -> results/wessels23____drf.csv (raw control values + calibrated DRF per perturbation) +``` From ceb72fdd30ed640dcd2796d5d0274bb8cffaad55 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 29 Jun 2026 09:55:52 +0100 Subject: [PATCH 05/15] docs: replace scoring.md with calibration.md and improve user guide flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename scoring.md → calibration.md and move it first in the user guide toctree, so readers encounter the conceptual framing (what a protocol is, what DRF and BDS measure) before the CLI usage page. Expand the page with proper math formalism: introduces s_bench notation, the two empirical controls, and the DRF/BDS formulas with inline interpretation. Add a single orienting sentence to user-guide/index.md. Update the cross-link in docs/index.md. Also include a pre-existing conf.py fix (bibtex_reference_style). Co-Authored-By: Claude Sonnet 4.6 --- docs/conf.py | 1 + docs/index.md | 2 +- docs/user-guide/calibration.md | 70 ++++++++++++++++++++++++++++++++++ docs/user-guide/index.md | 4 +- docs/user-guide/scoring.md | 36 ----------------- 5 files changed, 75 insertions(+), 38 deletions(-) create mode 100644 docs/user-guide/calibration.md delete mode 100644 docs/user-guide/scoring.md diff --git a/docs/conf.py b/docs/conf.py index 7469883..0d9561a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,6 +26,7 @@ release = info["Version"] bibtex_bibfiles = ["references.bib"] +bibtex_reference_style = "author_year" templates_path = ["_templates"] nitpicky = True needs_sphinx = "4.0" diff --git a/docs/index.md b/docs/index.md index 2ab513a..4686f1f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -16,7 +16,7 @@ and contribute back — a place for collaboration and alignment in the field. Ru specifying a dataset, one or more protocols, and a method of differential expression; the tool outputs calibration data: the **Dynamic Range Fraction (DRF)** and the **Bound Discrimination Score (BDS)** — quantifying how well the protocol separates real perturbation -signal from an uninformative baseline (see [How scoring works](user-guide/scoring.md)). +signal from an uninformative baseline (see [Calibration](user-guide/calibration.md)). ## Quick start diff --git a/docs/user-guide/calibration.md b/docs/user-guide/calibration.md new file mode 100644 index 0000000..72353d6 --- /dev/null +++ b/docs/user-guide/calibration.md @@ -0,0 +1,70 @@ +# Calibration + +scPertEval assesses each **evaluation protocol** — a representation $\phi$ paired with a +metric $d$ — for its **separability**: can it reliably distinguish a perturbation's true +response from an uninformative baseline? + +Let $s(\mathcal{X}, \mathcal{Y}) = d(\phi(\mathcal{X}), \phi(\mathcal{Y}))$ denote the +protocol-induced score, where **smaller values indicate better agreement** (similarity scores +are converted beforehand). + +## Controls + +For each perturbation $a$, every protocol is evaluated against two empirical controls: + +- **Positive control** $s_{\text{pos}}^{(a)}$ — the best realistic score: comparing the + observed cells against a **technical duplicate** (a held-out replicate); for pseudobulk + protocols, an **interpolated duplicate** is used for stability. +- **Negative control** $s_{\text{neg}}^{(a)}$ — an uninformative baseline: comparing against + the **all-perturbed reference excluding $a$** (full mean for pseudobulk; a subsample of + 8 192 cells by default for single-cell distances, configurable with `--subsample`). +Ideally $s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}$. + +## Dynamic Range Fraction (DRF) + +DRF asks: **how much of the available signal range does the protocol actually recover?** + +$$ +\operatorname{DRF}(a) += \frac{s_{\text{neg}}^{(a)} - s_{\text{pos}}^{(a)}}{s_{\text{neg}}^{(a)} - s_{\text{optim}} + \xi} +$$ + +where $s_{\text{optim}}$ is the protocol's ideal score (0 for distance metrics; set by the +`perfect` field in the protocol spec) and $\xi > 0$ is a small stabilising constant. +The numerator is the recovered gap (how much the positive control beats the negative); +the denominator is the total available dynamic range from baseline down to optimal. + +| $\operatorname{DRF}(a)$ | Meaning | +|---|---| +| $= 1$ | positive control achieves the optimal score | +| $= 0$ | positive and negative controls score equally | +| $< 0$ | positive control performs *worse* than the uninformative baseline | + +`--output drf` reports the mean/median of $\operatorname{DRF}(a)$ across perturbations. +Introduced by {cite}`Miller_2025`. + +## Bound Discrimination Score (BDS) + +BDS asks a simpler, binary question: **for what fraction of perturbations does the protocol +get the ordering right?** + +$$ +\operatorname{BDS} += \frac{1}{|\mathcal{P}|} + \sum_{a \in \mathcal{P}} + \mathbf{1}\!\left[s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}\right] +$$ + +It records whether the positive control beats the negative, but not by how much. +A protocol with low BDS cannot distinguish a technical replicate from a random reference; +its scores should not be trusted regardless of their magnitude. + +`--output bds` reports this fraction. Introduced by {cite}`Vollenweider_2026`. + +## DRF vs BDS + +The two scores are complementary. BDS checks the **sign** — does +$s_{\text{pos}}^{(a)} < s_{\text{neg}}^{(a)}$? DRF checks the **magnitude** — how far along +the full dynamic range is that gap? A protocol can have high BDS (ordering consistently +correct) yet low DRF (margin negligible relative to what is achievable). Use both together: +BDS as a pass/fail gate on directionality, DRF as a quantitative measure of signal recovery. diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index a5a568f..a34b3f3 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -1,10 +1,12 @@ # User guide +Start with [Calibration](calibration) to understand what scPertEval measures before diving into usage. + ```{toctree} :maxdepth: 1 +calibration usage -scoring protocols building-blocks ``` diff --git a/docs/user-guide/scoring.md b/docs/user-guide/scoring.md deleted file mode 100644 index f23e394..0000000 --- a/docs/user-guide/scoring.md +++ /dev/null @@ -1,36 +0,0 @@ -# How scoring works (the calibration) - -scPertEval's claim — a usable catalog of protocols — rests on **calibrating** each protocol -against two empirical controls per perturbation, so you can see whether a metric actually -separates signal from baseline rather than read a raw, uninterpretable number. - -- **positive control** — the best realistic candidate: the **technical duplicate** (a - held-out replicate) for single-cell protocols, the **interpolated duplicate** for pseudobulk. -- **negative control** — an uninformative baseline: the **all-perturbed reference, - excluding the target perturbation** (a full-resolution mean for pseudobulk; an 8192-cell - subsample for single-cell distances). - -## Dynamic Range Fraction (DRF) - -Where the protocol's value sits between the negative control (floor) and the perfect score, -anchored by the positive control: - -```text -DRF = (positive − negative) / (perfect − negative) # per perturbation, clipped to [-1, 1] -``` - -`--output drf` reports the mean/median across perturbations. High DRF means the protocol -discriminates real signal; near zero means it doesn't. Introduced by {cite:t}`Miller_2025`. - -## Bound Discrimination Score (BDS) - -The fraction of perturbations for which the positive control beats the negative control -under this protocol: - -```text -BDS = fraction of perturbations where positive control beats negative control # in [0, 1] -``` - -`--output bds` reports this fraction. It's a sensitivity check: a protocol with low BDS -can't even tell a technical replicate from an uninformative baseline, so its scores -shouldn't be trusted. Introduced by {cite:t}`Vollenweider_2026`. From c6149a7476685fde095a285157689b90155f71d5 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 29 Jun 2026 18:32:46 -0700 Subject: [PATCH 06/15] docs: streamline API reference pages --- docs/api.md | 89 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 81 insertions(+), 8 deletions(-) diff --git a/docs/api.md b/docs/api.md index 0b0dabc..ff5acac 100644 --- a/docs/api.md +++ b/docs/api.md @@ -12,6 +12,8 @@ RunConfig Protocol Calibrator + DEResult + Param ``` ## Runner @@ -28,9 +30,13 @@ ## Protocols -`scperteval.protocols.TABLE` — list of all `Protocol` objects. -`scperteval.protocols.PROTOCOLS` — `{name: Protocol}` dict. -`scperteval.protocols.GROUPS` — sorted list of group names. +- `scperteval.protocols.TABLE` — list of all `Protocol` objects. +- `scperteval.protocols.PROTOCOLS` — `{name: Protocol}` dict. +- `scperteval.protocols.GROUPS` — sorted list of group names. + +```{eval-rst} +.. protocol-table:: +``` ### Metrics @@ -38,20 +44,35 @@ .. module:: scperteval.protocols.metrics .. currentmodule:: scperteval.protocols.metrics +.. automodule:: scperteval.protocols.metrics + :no-members: + :no-index: + .. autosummary:: :toctree: generated + + pearson + mse + weighted_mse + energy_distance + unbiased_mmd_median + sinkhorn_w2 + rank_retrieval + de_auprc + de_auroc + de_overlap ``` ## Calibrators ```{eval-rst} -.. module:: scperteval.calibrators -.. currentmodule:: scperteval.calibrators - -.. autosummary:: - :toctree: generated +.. automodule:: scperteval.calibrators + :no-members: ``` +`scperteval.calibrators.CALIBRATORS` — `{name: Calibrator}` dict of built-in calibrators (`drf`, `bds`). +Add entries here to register a new calibrator; see [Add a calibrator](user-guide/building-blocks.md#add-a-calibrator). + ## Building blocks ### Differential expression @@ -60,8 +81,20 @@ .. module:: scperteval.blocks.de .. currentmodule:: scperteval.blocks.de +.. automodule:: scperteval.blocks.de + :no-members: + :no-index: + .. autosummary:: :toctree: generated + + DE_METHODS + moments + bh + ttest_from_moments + de_ttest + de_ttest_overestim + de_mwu ``` ### Feature spaces @@ -70,8 +103,40 @@ .. module:: scperteval.blocks.spaces .. currentmodule:: scperteval.blocks.spaces +.. automodule:: scperteval.blocks.spaces + :no-members: + :no-index: + +.. autosummary:: + :toctree: generated + + SPACES + register_de_space + top_space + pca_space + degs_space +``` + +### Control sources + +```{eval-rst} +.. automodule:: scperteval.sources + :no-members: +``` + +`scperteval.sources.SOURCES` — registry of all control/reference sources. +Add entries here to register a new source; see [Add a control source](user-guide/building-blocks.md#add-a-control-source). + +## Context + +```{eval-rst} +.. module:: scperteval.context +.. currentmodule:: scperteval.context + .. autosummary:: :toctree: generated + + Context ``` ## Registry @@ -94,6 +159,9 @@ .. autosummary:: :toctree: generated + + Dataset + to_dense ``` ```{eval-rst} @@ -102,4 +170,9 @@ .. autosummary:: :toctree: generated + + print_summary + write_rows + write_timing + write_de ``` From 8cc46ff0918f3ae6c03d52d672b7cd6d4a3c597d Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 18:33:30 -0700 Subject: [PATCH 07/15] chore: move package to src/ layout (carry score-mode code) --- {scperteval => src/scperteval}/__init__.py | 0 {scperteval => src/scperteval}/__main__.py | 0 {scperteval => src/scperteval}/blocks/__init__.py | 0 {scperteval => src/scperteval}/blocks/de.py | 0 {scperteval => src/scperteval}/blocks/spaces.py | 0 {scperteval => src/scperteval}/calibrators.py | 0 {scperteval => src/scperteval}/cli.py | 0 {scperteval => src/scperteval}/context.py | 0 {scperteval => src/scperteval}/dataset.py | 0 {scperteval => src/scperteval}/io.py | 0 {scperteval => src/scperteval}/predictions.py | 0 {scperteval => src/scperteval}/protocols/__init__.py | 0 {scperteval => src/scperteval}/protocols/metrics.py | 0 {scperteval => src/scperteval}/protocols/table.py | 0 {scperteval => src/scperteval}/reference.py | 0 {scperteval => src/scperteval}/registry.py | 0 {scperteval => src/scperteval}/runner.py | 0 {scperteval => src/scperteval}/sources.py | 0 {scperteval => src/scperteval}/types.py | 0 19 files changed, 0 insertions(+), 0 deletions(-) rename {scperteval => src/scperteval}/__init__.py (100%) rename {scperteval => src/scperteval}/__main__.py (100%) rename {scperteval => src/scperteval}/blocks/__init__.py (100%) rename {scperteval => src/scperteval}/blocks/de.py (100%) rename {scperteval => src/scperteval}/blocks/spaces.py (100%) rename {scperteval => src/scperteval}/calibrators.py (100%) rename {scperteval => src/scperteval}/cli.py (100%) rename {scperteval => src/scperteval}/context.py (100%) rename {scperteval => src/scperteval}/dataset.py (100%) rename {scperteval => src/scperteval}/io.py (100%) rename {scperteval => src/scperteval}/predictions.py (100%) rename {scperteval => src/scperteval}/protocols/__init__.py (100%) rename {scperteval => src/scperteval}/protocols/metrics.py (100%) rename {scperteval => src/scperteval}/protocols/table.py (100%) rename {scperteval => src/scperteval}/reference.py (100%) rename {scperteval => src/scperteval}/registry.py (100%) rename {scperteval => src/scperteval}/runner.py (100%) rename {scperteval => src/scperteval}/sources.py (100%) rename {scperteval => src/scperteval}/types.py (100%) diff --git a/scperteval/__init__.py b/src/scperteval/__init__.py similarity index 100% rename from scperteval/__init__.py rename to src/scperteval/__init__.py diff --git a/scperteval/__main__.py b/src/scperteval/__main__.py similarity index 100% rename from scperteval/__main__.py rename to src/scperteval/__main__.py diff --git a/scperteval/blocks/__init__.py b/src/scperteval/blocks/__init__.py similarity index 100% rename from scperteval/blocks/__init__.py rename to src/scperteval/blocks/__init__.py diff --git a/scperteval/blocks/de.py b/src/scperteval/blocks/de.py similarity index 100% rename from scperteval/blocks/de.py rename to src/scperteval/blocks/de.py diff --git a/scperteval/blocks/spaces.py b/src/scperteval/blocks/spaces.py similarity index 100% rename from scperteval/blocks/spaces.py rename to src/scperteval/blocks/spaces.py diff --git a/scperteval/calibrators.py b/src/scperteval/calibrators.py similarity index 100% rename from scperteval/calibrators.py rename to src/scperteval/calibrators.py diff --git a/scperteval/cli.py b/src/scperteval/cli.py similarity index 100% rename from scperteval/cli.py rename to src/scperteval/cli.py diff --git a/scperteval/context.py b/src/scperteval/context.py similarity index 100% rename from scperteval/context.py rename to src/scperteval/context.py diff --git a/scperteval/dataset.py b/src/scperteval/dataset.py similarity index 100% rename from scperteval/dataset.py rename to src/scperteval/dataset.py diff --git a/scperteval/io.py b/src/scperteval/io.py similarity index 100% rename from scperteval/io.py rename to src/scperteval/io.py diff --git a/scperteval/predictions.py b/src/scperteval/predictions.py similarity index 100% rename from scperteval/predictions.py rename to src/scperteval/predictions.py diff --git a/scperteval/protocols/__init__.py b/src/scperteval/protocols/__init__.py similarity index 100% rename from scperteval/protocols/__init__.py rename to src/scperteval/protocols/__init__.py diff --git a/scperteval/protocols/metrics.py b/src/scperteval/protocols/metrics.py similarity index 100% rename from scperteval/protocols/metrics.py rename to src/scperteval/protocols/metrics.py diff --git a/scperteval/protocols/table.py b/src/scperteval/protocols/table.py similarity index 100% rename from scperteval/protocols/table.py rename to src/scperteval/protocols/table.py diff --git a/scperteval/reference.py b/src/scperteval/reference.py similarity index 100% rename from scperteval/reference.py rename to src/scperteval/reference.py diff --git a/scperteval/registry.py b/src/scperteval/registry.py similarity index 100% rename from scperteval/registry.py rename to src/scperteval/registry.py diff --git a/scperteval/runner.py b/src/scperteval/runner.py similarity index 100% rename from scperteval/runner.py rename to src/scperteval/runner.py diff --git a/scperteval/sources.py b/src/scperteval/sources.py similarity index 100% rename from scperteval/sources.py rename to src/scperteval/sources.py diff --git a/scperteval/types.py b/src/scperteval/types.py similarity index 100% rename from scperteval/types.py rename to src/scperteval/types.py From 4358c71b37cb41020b007b338059bd2dc87a1f17 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 29 Jun 2026 18:36:47 -0700 Subject: [PATCH 08/15] chore(ruff,docstrings): lint-clean + numpy docstrings for unchanged modules --- src/scperteval/__init__.py | 10 +- src/scperteval/blocks/__init__.py | 3 +- src/scperteval/blocks/de.py | 102 +++++++++- src/scperteval/dataset.py | 11 +- src/scperteval/io.py | 27 ++- src/scperteval/protocols/__init__.py | 3 +- src/scperteval/protocols/metrics.py | 274 +++++++++++++++++++++++++-- src/scperteval/protocols/table.py | 49 +++-- src/scperteval/reference.py | 9 +- src/scperteval/registry.py | 31 ++- 10 files changed, 460 insertions(+), 59 deletions(-) diff --git a/src/scperteval/__init__.py b/src/scperteval/__init__.py index a320ab4..2a196bc 100644 --- a/src/scperteval/__init__.py +++ b/src/scperteval/__init__.py @@ -1,6 +1,12 @@ """Evaluation Protocols for Perturbation Studies.""" + import os as _os -for _v in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS", - "NUMEXPR_NUM_THREADS", "VECLIB_MAXIMUM_THREADS"): +for _v in ( + "OMP_NUM_THREADS", + "OPENBLAS_NUM_THREADS", + "MKL_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "VECLIB_MAXIMUM_THREADS", +): _os.environ.setdefault(_v, "1") diff --git a/src/scperteval/blocks/__init__.py b/src/scperteval/blocks/__init__.py index 404db76..162dacd 100644 --- a/src/scperteval/blocks/__init__.py +++ b/src/scperteval/blocks/__init__.py @@ -1,2 +1,3 @@ """Pluggable building blocks: DE methods and feature spaces.""" -from . import de, spaces # noqa: F401 (import for registration side effects) + +from . import de, spaces diff --git a/src/scperteval/blocks/de.py b/src/scperteval/blocks/de.py index f4518ed..204fc80 100644 --- a/src/scperteval/blocks/de.py +++ b/src/scperteval/blocks/de.py @@ -4,6 +4,7 @@ expressed through reusable ``moments`` / ``ttest_from_moments`` helpers so the context can cache a shared reference's moments and combine them cheaply. """ + from __future__ import annotations import numpy as np @@ -14,10 +15,43 @@ from ..types import DEResult DE_METHODS = Registry("de-method") +"""Registry of DE backends; keys are the ``--de-method`` names. + +Use :meth:`~scperteval.registry.Registry.register` to add a custom backend:: + + from scperteval.blocks.de import DE_METHODS, bh + from scperteval.types import DEResult + + @DE_METHODS.register("my_test", description="My custom DE test") + def de_my_test(target, reference): + score = ... # per-gene statistic, shape (G,) + pvalue = ... # per-gene raw p-values, shape (G,) + return DEResult(score=score, pvalue=pvalue, pvalue_adj=bh(pvalue)) + +Then ``--de-method my_test`` routes every DE-dependent unit through it. +""" def moments(X): - """Per-gene (mean, sample variance, n) for a cell matrix, sparse- or dense-aware.""" + """Per-gene mean, sample variance, and cell count for a cell matrix. + + Sparse- and dense-aware; uses :math:`\\text{Var}(X) = E[X^2] - E[X]^2` + with Bessel's correction. + + Parameters + ---------- + X : array-like, shape ``(n, G)`` + Cell matrix (sparse or dense). + + Returns + ------- + mean : numpy.ndarray, shape ``(G,)`` + Per-gene sample mean. + variance : numpy.ndarray, shape ``(G,)`` + Per-gene sample variance (floored at 0). + n : int + Number of cells. + """ n = X.shape[0] if sp.issparse(X): m = np.asarray(X.mean(0)).ravel() @@ -31,7 +65,22 @@ def moments(X): def bh(pvalue: np.ndarray) -> np.ndarray: - """Benjamini-Hochberg adjusted p-values.""" + """Benjamini-Hochberg adjusted p-values for FDR control :cite:p:`Vollenweider_2026`. + + Applied gene-wise inside each DE method to control the false discovery rate + across genes. The same procedure is used across perturbations to summarise + the overall sensitivity of a metric — see :cite:t:`Vollenweider_2026`. + + Parameters + ---------- + pvalue : numpy.ndarray + Array of raw p-values; non-finite values are carried through as ``nan``. + + Returns + ------- + numpy.ndarray + BH-adjusted p-values clipped to [0, 1]; same shape as ``pvalue``. + """ p = np.asarray(pvalue, dtype=np.float64) out = np.full(p.shape, np.nan) idx = np.where(np.isfinite(p))[0] @@ -44,11 +93,36 @@ def bh(pvalue: np.ndarray) -> np.ndarray: def ttest_from_moments(mt, vt, nt, mr, vr, nr) -> DEResult: - """Welch's t-test (scanpy convention); score = t-statistic.""" + """Welch's t-test from pre-computed per-gene moments (scanpy convention). + + Accepts moments directly so the context can cache the reference's moments + once and combine them cheaply for every perturbation. The ``score`` field + of the returned :class:`~scperteval.types.DEResult` is the t-statistic. + + Parameters + ---------- + mt : numpy.ndarray, shape ``(G,)`` + Target per-gene means. + vt : numpy.ndarray, shape ``(G,)`` + Target per-gene sample variances. + nt : int + Number of target cells. + mr : numpy.ndarray, shape ``(G,)`` + Reference per-gene means. + vr : numpy.ndarray, shape ``(G,)`` + Reference per-gene sample variances. + nr : int + Number of reference cells. + + Returns + ------- + ~scperteval.types.DEResult + ``score`` is the Welch t-statistic; ``pvalue_adj`` is BH-adjusted. + """ se2 = vt / nt + vr / nr with np.errstate(divide="ignore", invalid="ignore"): t = (mt - mr) / np.sqrt(se2) - df = se2 ** 2 / ((vt / nt) ** 2 / max(nt - 1, 1) + (vr / nr) ** 2 / max(nr - 1, 1)) + df = se2**2 / ((vt / nt) ** 2 / max(nt - 1, 1) + (vr / nr) ** 2 / max(nr - 1, 1)) t = np.nan_to_num(t, nan=0.0, posinf=0.0, neginf=0.0) df = np.where(np.isfinite(df) & (df > 0), df, 1.0) pval = np.nan_to_num(2.0 * stats.t.sf(np.abs(t), df), nan=1.0) @@ -57,6 +131,7 @@ def ttest_from_moments(mt, vt, nt, mr, vr, nr) -> DEResult: @DE_METHODS.register("t-test", description="Welch's t-test (default) — moment-based and fast") def de_ttest(target, reference) -> DEResult: + """Welch's t-test between target and reference cell matrices.""" return ttest_from_moments(*moments(target), *moments(reference)) @@ -71,8 +146,7 @@ def de_ttest_overestim(target, reference) -> DEResult: Identical to Welch's t-test except the reference group's cell count is replaced by the target's, which inflates the reference standard-error term ("overestimating" its variance for small target groups) and yields a more conservative statistic. Selectable as a DE - backend (``--de-method``/``--methods``) so new evaluation protocols can use it; no current - protocol does. + backend (``--de-method``/``--methods``); no current protocol uses it. """ mt, vt, nt = moments(target) mr, vr, _nr = moments(reference) @@ -97,10 +171,18 @@ def de_mwu(target, reference) -> DEResult: adata = ad.AnnData(np.vstack([Xt, Xr]).astype(np.float64)) adata.var_names = genes adata.obs["_g"] = ["target"] * nt + ["reference"] * nr - df = asymptotic_wilcoxon(adata, is_log1p=True, group_keys="_g", reference="reference", - n_threads=1, alternative="two-sided", use_continuity=True, - tie_correct=True, return_as_scanpy=False) - sub = df.xs("target", level=0).reindex(genes) + df = asymptotic_wilcoxon( + adata, + is_log1p=True, + group_keys="_g", + reference="reference", + n_threads=1, + alternative="two-sided", + use_continuity=True, + tie_correct=True, + return_as_scanpy=False, + ) + sub = df.xs("target", level=0).reindex(genes) # pyright: ignore[reportAttributeAccessIssue] u = sub["statistic"].to_numpy(dtype=np.float64) pval = np.nan_to_num(sub["p_value"].to_numpy(dtype=np.float64), nan=1.0) cliff = 2.0 * u / (nt * nr) - 1.0 diff --git a/src/scperteval/dataset.py b/src/scperteval/dataset.py index fa8e307..8b7e742 100644 --- a/src/scperteval/dataset.py +++ b/src/scperteval/dataset.py @@ -1,4 +1,5 @@ """Thin wrapper over a preprocessed AnnData with a perturbation column.""" + from __future__ import annotations import zlib @@ -11,7 +12,7 @@ def _seed(seed: int, *tags) -> np.random.Generator: - key = (seed,) + tuple(zlib.crc32(str(t).encode()) for t in tags) + key = (seed, *(zlib.crc32(str(t).encode()) for t in tags)) return np.random.default_rng(np.array(key, dtype=np.uint32)) @@ -27,7 +28,7 @@ def __init__(self, adata, cfg: RunConfig): self._index() @classmethod - def load(cls, path: str, cfg: RunConfig) -> "Dataset": + def load(cls, path: str, cfg: RunConfig) -> Dataset: return cls(ad.read_h5ad(path), cfg) def _index(self): @@ -63,7 +64,8 @@ def control_cells(self, cap: int) -> np.ndarray: def all_perturbed_indices(self, cap: int) -> np.ndarray: """One all-perturbed subsample (the reference sample, shared across perturbations). - The "pool" tag is a fixed reproducibility salt for the draw, not a public name.""" + The "pool" tag is a fixed reproducibility salt for the draw, not a public name. + """ return self._cap(np.where(self.pert != self.cfg.control_label)[0], cap, "pool") def allpert_mean_except(self, pert: str) -> np.ndarray: @@ -72,7 +74,8 @@ def allpert_mean_except(self, pert: str) -> np.ndarray: def allpert_mean(self) -> np.ndarray: """Mean of all per-perturbation means (no target exclusion); a single vector - shared across perturbations, used as the cross-perturbation ranking baseline.""" + shared across perturbations, used as the cross-perturbation ranking baseline. + """ return self._mean_sum / max(len(self.perturbations), 1) def control_mean(self) -> np.ndarray: diff --git a/src/scperteval/io.py b/src/scperteval/io.py index b0c53cd..03b8fc5 100644 --- a/src/scperteval/io.py +++ b/src/scperteval/io.py @@ -1,4 +1,5 @@ """Human-readable summary plus a per-perturbation CSV named with dataset + time.""" + from __future__ import annotations from pathlib import Path @@ -8,6 +9,7 @@ def print_summary(cfg, aggregates: dict, calibrator, protocols) -> None: + """Print a formatted table of aggregate scores for every protocol.""" name = Path(cfg.dataset).stem print(f"\n{name} · {cfg.de_method} · subsample={cfg.subsample} · seed={cfg.seed} · output={cfg.output}\n") agg_keys = sorted({k for v in aggregates.values() for k in v}) @@ -22,11 +24,16 @@ def print_summary(cfg, aggregates: dict, calibrator, protocols) -> None: def write_rows(cfg, rows: list, timestamp: str) -> Path: + """Write per-perturbation rows (raw controls + calibrated score) to a timestamped CSV.""" out_dir = Path(cfg.out_dir) out_dir.mkdir(parents=True, exist_ok=True) df = pd.DataFrame(rows) - for col, val in (("dataset", Path(cfg.dataset).stem), ("de_method", cfg.de_method), - ("subsample", cfg.subsample), ("seed", cfg.seed)): + for col, val in ( + ("dataset", Path(cfg.dataset).stem), + ("de_method", cfg.de_method), + ("subsample", cfg.subsample), + ("seed", cfg.seed), + ): df[col] = val path = out_dir / f"{Path(cfg.dataset).stem}__{timestamp}__{cfg.output}.csv" df.to_csv(path, index=False) @@ -37,9 +44,16 @@ def write_timing(cfg, timed: list, timestamp: str) -> Path: """Write per-protocol wall-clock seconds (one row per protocol).""" out_dir = Path(cfg.out_dir) out_dir.mkdir(parents=True, exist_ok=True) - rows = [{"dataset": Path(cfg.dataset).stem, "protocol": p.name, - "representation": p.representation, "space": p.space, "seconds": seconds} - for p, seconds in timed] + rows = [ + { + "dataset": Path(cfg.dataset).stem, + "protocol": p.name, + "representation": p.representation, + "space": p.space, + "seconds": seconds, + } + for p, seconds in timed + ] path = out_dir / f"{Path(cfg.dataset).stem}__{timestamp}__timing.csv" pd.DataFrame(rows).to_csv(path, index=False) return path @@ -49,7 +63,8 @@ def write_de(cfg, genes, perturbations, results: dict, timestamp: str) -> Path: """Write per-gene DE (statistic + adjusted p) per method to an HDF5 file. Layout: ``genes``, ``perturbations``, and one group per method holding - ``statistic`` and ``pvalue_adj`` matrices (perturbations x genes).""" + ``statistic`` and ``pvalue_adj`` matrices (perturbations x genes). + """ import h5py out_dir = Path(cfg.out_dir) diff --git a/src/scperteval/protocols/__init__.py b/src/scperteval/protocols/__init__.py index ed1c0e8..2f351c2 100644 --- a/src/scperteval/protocols/__init__.py +++ b/src/scperteval/protocols/__init__.py @@ -1,2 +1,3 @@ """Evaluation protocols: pure metrics plus the declarative protocol table.""" -from .table import GROUPS, PROTOCOLS, TABLE # noqa: F401 + +from .table import GROUPS, PROTOCOLS, TABLE diff --git a/src/scperteval/protocols/metrics.py b/src/scperteval/protocols/metrics.py index 58515e7..89dc766 100644 --- a/src/scperteval/protocols/metrics.py +++ b/src/scperteval/protocols/metrics.py @@ -1,9 +1,9 @@ -"""Evaluation-protocol metrics — the exact implementation of every metric. +r"""Evaluation-protocol metrics — the exact implementation of every metric. A metric takes the ground-truth and a prediction (whichever control is being scored) plus the context, and returns a score. The protocol's ``representation`` sets each datapoint's shape — ``centroid`` -> a 1-D pseudobulk vector, ``population`` -> a (cells x genes) array, -``de`` -> a DEResult (GT) / |score| ranking (prediction). Its ``scope`` sets the call: a +``de`` -> a DEResult (GT) / \|score\| ranking (prediction). Its ``scope`` sets the call: a ``perturbation``-scope metric gets one perturbation's (gt, prediction) and returns a scalar; a ``dataset``-scope metric gets the list of every perturbation's gt and prediction and returns one score per perturbation (e.g. ``rank_retrieval``). @@ -12,12 +12,78 @@ scikit-learn, geomloss) are relied upon. So a metric is completely defined by its function below plus its row in ``table.py`` — nothing is hidden behind another layer. """ + from __future__ import annotations import numpy as np from sklearn.metrics import average_precision_score, roc_auc_score +# --- shared parameter blocks, substituted into docstrings at decoration time --- + +_CENTROID = """\ +gt : numpy.ndarray + Ground-truth pseudobulk profile, shape ``(G,)``. +prediction : numpy.ndarray + Predicted pseudobulk profile, shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + +_CENTROID_W = """\ +gt : numpy.ndarray + Ground-truth pseudobulk profile, shape ``(G,)``. +prediction : numpy.ndarray + Predicted pseudobulk profile, shape ``(G,)``. +ctx : Context + Provides per-gene WMSE weights via ``ctx.wmse_weights``.""" + +_POPULATION = """\ +gt : numpy.ndarray + Ground-truth cell matrix, shape ``(n, G)``. +prediction : numpy.ndarray + Predicted cell matrix, shape ``(m, G)``. +ctx : Context + Unused; present for signature compatibility.""" + +_DATASET = """\ +gt : list of numpy.ndarray + Ground-truth centroids, one per perturbation, each shape ``(G,)``. +prediction : list of numpy.ndarray + Predicted centroids, one per perturbation, each shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + +_DE = """\ +gt : ~scperteval.types.DEResult + Ground-truth DE result; ``gt.pvalue_adj`` defines the positive class. +prediction : numpy.ndarray + Per-gene absolute DE score ranking from the candidate source, shape ``(G,)``. +ctx : Context + Unused; present for signature compatibility.""" + + +def _doc(**subs): + """Decorator that substitutes %(key)s placeholders, propagating surrounding indentation. + + Python's ``%`` substitution only indents the first line of a multi-line value. + This decorator detects the column position of each placeholder and re-indents all + continuation lines to match, so the substituted text stays inside the RST section. + """ + def deco(fn): + doc = fn.__doc__ + for key, value in subs.items(): + placeholder = f"%({key})s" + while placeholder in doc: + idx = doc.index(placeholder) + line_start = doc.rfind('\n', 0, idx) + 1 + indent = ' ' * (idx - line_start) + indented = ('\n' + indent).join(value.split('\n')) + doc = doc[:idx] + indented + doc[idx + len(placeholder):] + fn.__doc__ = doc + return fn + return deco + + def _sq_dists(X, Y): """Pairwise squared euclidean distances via ||x||^2 + ||y||^2 - 2 x.y. @@ -37,23 +103,98 @@ def _within_unbiased(sq, n): return float(np.sqrt(sq).sum() / (n * (n - 1))) +@_doc(params=_CENTROID) def pearson(gt, prediction, ctx): + """Pearson correlation between pseudobulk profiles. + + .. math:: + + r = \\frac{\\sum_g (gt_g - \\bar{gt})(pred_g - \\bar{pred})}{ + \\sqrt{\\sum_g (gt_g - \\bar{gt})^2 \\cdot \\sum_g (pred_g - \\bar{pred})^2}} + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Pearson r in [-1, 1]; 1 is perfect. + """ return float(np.corrcoef(gt, prediction)[0, 1]) +@_doc(params=_CENTROID) def mse(gt, prediction, ctx): + """Mean squared error between pseudobulk profiles. + + .. math:: + + \\text{MSE} = \\frac{1}{G}\\sum_{g=1}^G (gt_g - pred_g)^2 + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Non-negative MSE; 0 is perfect. + """ return float(np.mean((gt - prediction) ** 2)) +@_doc(params=_CENTROID_W) def weighted_mse(gt, prediction, ctx, exp=2.0): + """MSE weighted by ground-truth effect size raised to ``exp``. + + Weights are min-max normalised per-gene; high-effect genes contribute more. + + .. math:: + + \\text{wMSE} = \\sum_g w_g \\,(gt_g - pred_g)^2, \\quad + w_g \\propto |s_g|^{\\text{exp}} / \\sum_{g'} |s_{g'}|^{\\text{exp}} + + where :math:`s_g` is the ground-truth DE t-statistic for gene :math:`g`. + + Parameters + ---------- + %(params)s + exp : float + Exponent applied to the effect-size weights (default 2.0). + + Returns + ------- + float + Non-negative weighted MSE; 0 is perfect. + """ w = ctx.wmse_weights(ctx.current_pert) ** exp total = w.sum() w = w / total if total > 0 else np.full(w.size, 1.0 / w.size) return float(np.sum(w * (gt - prediction) ** 2)) +@_doc(params=_POPULATION) def energy_distance(gt, prediction, ctx): - """Szekely-Rizzo energy distance with bias-corrected within terms.""" + """Székely–Rizzo energy distance with bias-corrected within-population terms. + + .. math:: + + E(X, Y) = 2\\,\\mathbb{E}[\\|X - Y\\|] + - \\mathbb{E}[\\|X - X'\\|] - \\mathbb{E}[\\|Y - Y'\\|] + + Within-population terms use the unbiased (U-statistic) estimator. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + Energy distance >= 0; 0 is perfect (identical distributions). + Returns ``nan`` if either population is empty. + """ if len(gt) == 0 or len(prediction) == 0: return float("nan") X = gt.astype(np.float64) @@ -64,8 +205,30 @@ def energy_distance(gt, prediction, ctx): return float(2.0 * cross - xx - yy) +@_doc(params=_POPULATION) def unbiased_mmd_median(gt, prediction, ctx): - """Unbiased RBF-MMD^2 with a single median-heuristic bandwidth (Gretton 2012).""" + """Unbiased RBF-MMD² with median-heuristic bandwidth (Gretton 2012). + + .. math:: + + \\widehat{\\text{MMD}}^2(X, Y) + = \\frac{1}{n(n-1)} \\sum_{i \\neq j} k(x_i, x_j) + + \\frac{1}{m(m-1)} \\sum_{i \\neq j} k(y_i, y_j) + - \\frac{2}{nm} \\sum_{i,j} k(x_i, y_j) + + with :math:`k(x,y) = \\exp(-\\|x-y\\|^2 / 2\\sigma^2)` and :math:`\\sigma` the median + pairwise Euclidean distance over the pooled sample. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + MMD² (may be slightly negative due to estimation variance); 0 is perfect. + Returns ``nan`` if either population has fewer than 2 cells. + """ if len(gt) < 2 or len(prediction) < 2: return float("nan") X = gt.astype(np.float64) @@ -89,8 +252,29 @@ def unbiased_mmd_median(gt, prediction, ctx): _geomloss_cache: dict = {} +@_doc(params=_POPULATION) def sinkhorn_w2(gt, prediction, ctx, blur=0.05): - """Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2).""" + """Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2). + + .. math:: + + W_2(X, Y) = \\sqrt{2\\,S_\\varepsilon(X, Y)} + + where :math:`S_\\varepsilon` is the debiased Sinkhorn divergence with blur + :math:`\\varepsilon`. Requires ``geomloss`` and ``torch``. + + Parameters + ---------- + %(params)s + blur : float + Sinkhorn entropic regularisation parameter (default 0.05). + + Returns + ------- + float + W2 distance >= 0; 0 is perfect. + Returns ``nan`` if either population is empty. + """ if len(gt) == 0 or len(prediction) == 0: return float("nan") import torch @@ -110,14 +294,32 @@ def sinkhorn_w2(gt, prediction, ctx, blur=0.05): return float(np.sqrt(max(2.0 * val, 0.0))) +@_doc(params=_DATASET) def rank_retrieval(gt, prediction, ctx, transpose=False): - """Cross-perturbation retrieval rank (0 = best, lower is better) — a dataset-scope metric. + """Cross-perturbation retrieval rank — dataset-scope metric, lower is better. + + Builds the ``(n x n)`` squared-distance matrix between all predicted and ground-truth + centroids, then reads off the diagonal rank (column-wise by default). + + .. math:: + + \\text{rank}(a) = \\frac{\\text{rank}_{\\text{col}}(D_{aa})}{n - 1}, \\quad + D_{ij} = \\|P_i - G_j\\|^2 + + where :math:`P_i` and :math:`G_j` are the predicted and ground-truth centroids. + ``transpose_rank`` transposes the matrix first (each prediction ranked among all GTs). + Tie-breaking noise (seed 42) matches the DRF calibration convention. - ``gt`` and ``prediction`` are the lists of every perturbation's centroid (one per - perturbation); this returns one score per perturbation. In the prediction-vs-GT - squared-distance matrix, ``rank`` ranks each GT's own prediction against all predictions - (column-wise); ``transpose_rank`` ranks each prediction's own GT against all GTs - (row-wise). Normalised by n-1, with the drf tie-breaking noise (seed 42). + Parameters + ---------- + %(params)s + transpose : bool + If ``True``, rank row-wise (each prediction vs all GTs) instead of column-wise. + + Returns + ------- + np.ndarray + Per-perturbation normalised rank in [0, 1]; 0 is a perfect top-1 retrieval. """ G = np.vstack(gt) P = np.vstack(prediction) @@ -130,21 +332,71 @@ def rank_retrieval(gt, prediction, ctx, transpose=False): return np.diag(ranks).astype(np.float64) / max(n - 1, 1) +@_doc(params=_DE) def de_auprc(gt, prediction, ctx): + """Area under the precision-recall curve for DEG recovery. + + Positive class: ground-truth DEGs with ``gt.pvalue_adj < 0.05``. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + AUPRC in [0, 1]; higher is better. + Returns ``nan`` if all genes fall in the same class. + """ labels = gt.pvalue_adj < 0.05 if labels.sum() == 0 or labels.sum() == labels.size: return float("nan") return float(average_precision_score(labels, prediction)) +@_doc(params=_DE) def de_auroc(gt, prediction, ctx): + """Area under the ROC curve for DEG recovery. + + Positive class: ground-truth DEGs with ``gt.pvalue_adj < 0.05``. + + Parameters + ---------- + %(params)s + + Returns + ------- + float + AUROC in [0, 1]; higher is better. + Returns ``nan`` if all genes fall in the same class. + """ labels = gt.pvalue_adj < 0.05 if labels.sum() == 0 or labels.sum() == labels.size: return float("nan") return float(roc_auc_score(labels, prediction)) +@_doc(params=_DE) def de_overlap(gt, prediction, ctx, k=50): + """Top-k overlap between ground-truth and predicted DE gene rankings. + + .. math:: + + \\text{Overlap}_k + = \\frac{|\\text{top-}k(|gt.score|) \\cap \\text{top-}k(pred)|}{k} + + Parameters + ---------- + %(params)s + k : int + Number of top genes to intersect (default 50). + + Returns + ------- + float + Fraction of top-k genes shared, in [0, 1]; higher is better. + Returns ``nan`` if k >= number of genes. + """ truth = np.abs(gt.score) if k >= truth.size: return float("nan") diff --git a/src/scperteval/protocols/table.py b/src/scperteval/protocols/table.py index 946609b..0cd535e 100644 --- a/src/scperteval/protocols/table.py +++ b/src/scperteval/protocols/table.py @@ -5,30 +5,42 @@ with no value the family default is used. To add a protocol, write a metric in ``metrics.py`` and add one row below. """ + from __future__ import annotations from functools import partial +from typing import Any from ..blocks.spaces import degs_space, pca_space, top_space from ..types import Param, Protocol from . import metrics as M - # --- parameter families: a CLI value selects a feature space (or feeds the metric) --- -top_k = Param("k", int, 50, space=top_space) # top-k DEGs by effect size -pca_k = Param("k", int, 50, space=pca_space) # k principal components +top_k = Param("k", int, 50, space=top_space) # top-k DEGs by effect size +pca_k = Param("k", int, 50, space=pca_space) # k principal components degs_padj = Param("padj", float, 0.05, space=degs_space) # DEGs at adjusted p < padj -overlap_k = Param("k", int, 50) # passed straight to de_overlap's k +overlap_k = Param("k", int, 50) # passed straight to de_overlap's k # --- shared wiring bundles (controls + score scale), splatted into rows with ** --- -_PB = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") -_PB_CTRL = dict(group="pseudobulk", positive="interpolated", negative="control") -_LOWER = dict(better="lower", perfect=0.0) -_DIST = dict(group="distributional", positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0) -_DE = dict(group="de", positive="tech_dup", negative="all_perturbed", reference="all_perturbed", - neg_reference="control", better="higher", perfect=1.0) -_RANK = dict(group="pseudobulk", positive="interpolated", negative="global_mean", better="lower", perfect=0.0) +_PB: dict[str, Any] = dict(group="pseudobulk", positive="interpolated", negative="all_perturbed_mean") +_PB_CTRL: dict[str, Any] = dict(group="pseudobulk", positive="interpolated", negative="control") +_LOWER: dict[str, Any] = dict(better="lower", perfect=0.0) +_DIST: dict[str, Any] = dict( + group="distributional", positive="tech_dup", negative="all_perturbed", better="lower", perfect=0.0 +) +_DE: dict[str, Any] = dict( + group="de", + positive="tech_dup", + negative="all_perturbed", + reference="all_perturbed", + neg_reference="control", + better="higher", + perfect=1.0, +) +_RANK: dict[str, Any] = dict( + group="pseudobulk", positive="interpolated", negative="global_mean", better="lower", perfect=0.0 +) TABLE = [ @@ -43,14 +55,14 @@ Protocol("mse_top_k", M.mse, representation="centroid", param=top_k, **_PB, **_LOWER), Protocol("mse_degs_padj", M.mse, representation="centroid", param=degs_padj, **_PB, **_LOWER), Protocol("pearson_pert_top_k", M.pearson, representation="centroid", centering="allpert", param=top_k, **_PB_CTRL), - Protocol("pearson_pert_degs_padj", M.pearson, representation="centroid", centering="allpert", param=degs_padj, **_PB_CTRL), - + Protocol( + "pearson_pert_degs_padj", M.pearson, representation="centroid", centering="allpert", param=degs_padj, **_PB_CTRL + ), # --- cross-perturbation retrieval rank (dataset-wide over centroids) --- - Protocol("rank", partial(M.rank_retrieval, transpose=False), - representation="centroid", scope="dataset", **_RANK), - Protocol("transpose_rank", partial(M.rank_retrieval, transpose=True), - representation="centroid", scope="dataset", **_RANK), - + Protocol("rank", partial(M.rank_retrieval, transpose=False), representation="centroid", scope="dataset", **_RANK), + Protocol( + "transpose_rank", partial(M.rank_retrieval, transpose=True), representation="centroid", scope="dataset", **_RANK + ), # --- distributional: distances between cell populations (positive = technical duplicate) --- Protocol("unbiased_mmd_median_top_k", M.unbiased_mmd_median, representation="population", param=top_k, **_DIST), Protocol("unbiased_mmd_median_pca_k", M.unbiased_mmd_median, representation="population", param=pca_k, **_DIST), @@ -58,7 +70,6 @@ Protocol("energy_distance_pca_k", M.energy_distance, representation="population", param=pca_k, **_DIST), Protocol("sinkhorn_w2_top_k", M.sinkhorn_w2, representation="population", param=top_k, **_DIST), Protocol("sinkhorn_w2_pca_k", M.sinkhorn_w2, representation="population", param=pca_k, **_DIST), - # --- differential expression: GT DEGs vs prediction ranking --- Protocol("de_auprc", M.de_auprc, representation="de", **_DE), Protocol("de_auroc", M.de_auroc, representation="de", **_DE), diff --git a/src/scperteval/reference.py b/src/scperteval/reference.py index 040c893..b140301 100644 --- a/src/scperteval/reference.py +++ b/src/scperteval/reference.py @@ -1,8 +1,11 @@ """The comparison reference: a fixed cell sample served leave-one-out.""" + from __future__ import annotations import warnings +import numpy as np + class Reference: """A comparison sample of cells (the all-perturbed subsample, or non-targeting @@ -16,13 +19,13 @@ class Reference: """ def __init__(self, cells, labels=None, warn_frac: float = 0.10): - self.cells = cells # densified once, (n_cells, n_genes) - self.labels = labels # per-cell perturbation, or None + self.cells = cells # densified once, (n_cells, n_genes) + self.labels = labels # per-cell perturbation, or None self.warn_frac = warn_frac self._n = len(cells) self._warned: set = set() - def keep(self, exclude) -> "object": + def keep(self, exclude) -> np.ndarray | None: """Boolean mask of the cells to keep, or None when nothing is excluded.""" if self.labels is None: return None diff --git a/src/scperteval/registry.py b/src/scperteval/registry.py index 25418c0..d6ba0da 100644 --- a/src/scperteval/registry.py +++ b/src/scperteval/registry.py @@ -1,35 +1,62 @@ """A minimal decorator registry for the pluggable building blocks.""" + from __future__ import annotations -from typing import Callable +from collections.abc import Callable class Registry: - """Maps a name to a function plus optional metadata, populated by decoration.""" + """Maps a name to a function plus optional metadata, populated by decoration. + + Parameters + ---------- + kind : str + Human-readable label used in error messages (e.g. ``"de-method"``). + + Example + ------- + >>> from scperteval.registry import Registry + >>> MY_REG = Registry("example") + >>> @MY_REG.register("double", description="multiply by 2") + ... def double(x): + ... return x * 2 + >>> MY_REG["double"](3) + 6 + >>> MY_REG.meta("double") + {'description': 'multiply by 2'} + >>> MY_REG.names() + ['double'] + """ def __init__(self, kind: str): self.kind = kind self._items: dict[str, tuple[Callable, dict]] = {} def register(self, name: str, **meta) -> Callable: + """Decorator that registers a function under ``name`` with optional metadata.""" def deco(fn: Callable) -> Callable: self._items[name] = (fn, meta) return fn + return deco def add(self, name: str, fn: Callable, **meta) -> None: + """Register a function under ``name`` without using the decorator form.""" self._items[name] = (fn, meta) def __getitem__(self, name: str) -> Callable: + """Return the function registered under ``name``.""" if name not in self._items: raise KeyError(f"unknown {self.kind} {name!r}; available: {self.names()}") return self._items[name][0] def meta(self, name: str) -> dict: + """Return the metadata dict registered alongside ``name``.""" return self._items[name][1] def __contains__(self, name: str) -> bool: return name in self._items def names(self) -> list[str]: + """Sorted list of registered names.""" return sorted(self._items) From c648b58757efb907f9cc23ebc317a0ba8a889875 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 18:36:47 -0700 Subject: [PATCH 09/15] chore(ruff): apply format to score-mode modules --- src/scperteval/blocks/spaces.py | 22 ++++-- src/scperteval/calibrators.py | 13 +++- src/scperteval/cli.py | 133 +++++++++++++++++++++----------- src/scperteval/context.py | 38 +++++---- src/scperteval/predictions.py | 6 +- src/scperteval/runner.py | 21 +++-- src/scperteval/sources.py | 69 +++++++++++------ src/scperteval/types.py | 15 ++-- 8 files changed, 212 insertions(+), 105 deletions(-) diff --git a/src/scperteval/blocks/spaces.py b/src/scperteval/blocks/spaces.py index 96f8ab7..26e34f9 100644 --- a/src/scperteval/blocks/spaces.py +++ b/src/scperteval/blocks/spaces.py @@ -7,6 +7,7 @@ templates); the default instances created at import are what ``scperteval list spaces`` shows. ``description`` is shown by ``scperteval list spaces``. """ + from __future__ import annotations import numpy as np @@ -42,8 +43,9 @@ def top_space(k: int) -> str: """top-k genes by |ground-truth effect size| (registered on demand).""" name = f"top_{k}" if name not in SPACES: - register_de_space(name, field="score", top=k, - description=f"top {k} genes by ground-truth effect size, per perturbation") + register_de_space( + name, field="score", top=k, description=f"top {k} genes by ground-truth effect size, per perturbation" + ) return name @@ -51,8 +53,12 @@ def degs_space(padj: float) -> str: """ground-truth DEGs at adjusted p < padj (registered on demand).""" name = f"degs_{padj:g}" if name not in SPACES: - register_de_space(name, field="pvalue_adj", threshold=(lambda v, p=padj: v < p), - description=f"ground-truth DEGs at adjusted p < {padj:g}, per perturbation") + register_de_space( + name, + field="pvalue_adj", + threshold=(lambda v, p=padj: v < p), + description=f"ground-truth DEGs at adjusted p < {padj:g}, per perturbation", + ) return name @@ -60,8 +66,12 @@ def pca_space(k: int) -> str: """top-k principal components (registered on demand).""" name = f"pca_{k}" if name not in SPACES: - SPACES.add(name, lambda X, ctx, pert, k=k: ctx.pca(k).transform(to_dense(X))[:, :k], - global_space=True, description=f"top {k} principal components (fit on the dataset)") + SPACES.add( + name, + lambda X, ctx, pert, k=k: ctx.pca(k).transform(to_dense(X))[:, :k], + global_space=True, + description=f"top {k} principal components (fit on the dataset)", + ) return name diff --git a/src/scperteval/calibrators.py b/src/scperteval/calibrators.py index 1a1ce00..a1b9a45 100644 --- a/src/scperteval/calibrators.py +++ b/src/scperteval/calibrators.py @@ -2,6 +2,7 @@ per-metric score. Each declares the control roles it needs, a per-perturbation combine, and a cross-perturbation aggregate. """ + from __future__ import annotations import numpy as np @@ -29,17 +30,23 @@ def _bds_per_pert(raws, p): CALIBRATORS = { "drf": Calibrator( - "drf", ("positive", "negative"), _drf_per_pert, + "drf", + ("positive", "negative"), + _drf_per_pert, lambda v: {"mean": float(np.nanmean(v)), "median": float(np.nanmedian(v))}, description="Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025)", ), "bds": Calibrator( - "bds", ("positive", "negative"), _bds_per_pert, + "bds", + ("positive", "negative"), + _bds_per_pert, lambda v: {"bds": float(np.nanmean(v))}, description="Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026)", ), "score": Calibrator( - "score", ("prediction",), lambda raws, p: raws["prediction"], + "score", + ("prediction",), + lambda raws, p: raws["prediction"], lambda v: {"mean": float(np.nanmean(v)), "median": float(np.nanmedian(v))}, description="raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode)", ), diff --git a/src/scperteval/cli.py b/src/scperteval/cli.py index 048e665..4814654 100644 --- a/src/scperteval/cli.py +++ b/src/scperteval/cli.py @@ -1,4 +1,5 @@ """scPertEval command-line interface.""" + from __future__ import annotations import argparse @@ -27,7 +28,7 @@ def _resolve_token(token: str) -> list[Protocol]: return [_concrete(p) for p in TABLE] if token in GROUPS: return [_concrete(p) for p in TABLE if p.group == token] - if "=" in token: # a tunable protocol with a value, e.g. mse_top_k=30 + if "=" in token: # a tunable protocol with a value, e.g. mse_top_k=30 name, _, value = token.partition("=") p = PROTOCOLS.get(name) if p is None or not p.parameterised: @@ -55,7 +56,8 @@ def resolve_protocols(specs: list[str]) -> list[Protocol]: def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: """Run every protocol over the dataset, print the summary, and write the per-perturbation CSV. Shared by ``calibrate`` and ``score`` (prediction vs ground truth); they differ only - in how ``ctx`` is built and which calibrator ``cfg.output`` selects.""" + in how ``ctx`` is built and which calibrator ``cfg.output`` selects. + """ calibrator = CALIBRATORS[cfg.output] ctx.warm(protocols) aggregates, rows, timed = {}, [], [] @@ -76,11 +78,19 @@ def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: def cmd_calibrate(args) -> None: protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( - dataset=args.dataset, protocols=[p.name for p in protocols], de_method=args.de_method, - subsample=args.subsample, seed=args.seed, positive=args.positive, - negative=args.negative, output=args.output, out_dir=args.out_dir, - workers=args.workers, perturbation_key=args.perturbation_key, - control_label=args.control_label, min_cells=args.min_cells, + dataset=args.dataset, + protocols=[p.name for p in protocols], + de_method=args.de_method, + subsample=args.subsample, + seed=args.seed, + positive=args.positive, + negative=args.negative, + output=args.output, + out_dir=args.out_dir, + workers=args.workers, + perturbation_key=args.perturbation_key, + control_label=args.control_label, + min_cells=args.min_cells, profile=args.profile, ) ctx = Context(Dataset.load(cfg.dataset, cfg), cfg) @@ -90,11 +100,20 @@ def cmd_calibrate(args) -> None: def cmd_score(args) -> None: protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( - dataset=args.dataset, protocols=[p.name for p in protocols], de_method=args.de_method, - subsample=args.subsample, seed=args.seed, output="score", out_dir=args.out_dir, - workers=args.workers, perturbation_key=args.perturbation_key, - control_label=args.control_label, min_cells=args.min_cells, profile=args.profile, - predictions=args.predictions, truth="gt_all_cells", + dataset=args.dataset, + protocols=[p.name for p in protocols], + de_method=args.de_method, + subsample=args.subsample, + seed=args.seed, + output="score", + out_dir=args.out_dir, + workers=args.workers, + perturbation_key=args.perturbation_key, + control_label=args.control_label, + min_cells=args.min_cells, + profile=args.profile, + predictions=args.predictions, + truth="gt_all_cells", ) ds = Dataset.load(cfg.dataset, cfg) ctx = Context(ds, cfg) @@ -105,10 +124,16 @@ def cmd_score(args) -> None: def cmd_de(args) -> None: methods = [m.strip() for m in args.methods.split(",") if m.strip()] cfg = RunConfig( - dataset=args.dataset, protocols=[], de_method=methods[0], - subsample=args.subsample, seed=args.seed, out_dir=args.out_dir, - workers=args.workers, min_cells=args.min_cells, - perturbation_key=args.perturbation_key, control_label=args.control_label, + dataset=args.dataset, + protocols=[], + de_method=methods[0], + subsample=args.subsample, + seed=args.seed, + out_dir=args.out_dir, + workers=args.workers, + min_cells=args.min_cells, + perturbation_key=args.perturbation_key, + control_label=args.control_label, ) ctx = Context(Dataset.load(cfg.dataset, cfg), cfg) ctx._ensure_ref_sums() @@ -123,10 +148,12 @@ def reg(registry, fmt): return [fmt(n, registry.meta(n)) for n in registry.names()] if args.what == "protocols": + def descr(p): scope = "" if p.scope == "perturbation" else f", {p.scope}-wide" knob = f"{p.param.name}=…" if p.parameterised else f"space={p.space}" return f"{p.group}, {p.representation}{scope}, {knob}" + lines = [f"{p.name:24s} ({descr(p)})" for p in TABLE] elif args.what == "de-methods": lines = reg(DE_METHODS, lambda n, m: f"{n:10s} — {m.get('description', '')}") @@ -143,60 +170,78 @@ def main(argv=None) -> None: parser = argparse.ArgumentParser(prog="scperteval", description=__doc__) sub = parser.add_subparsers(dest="cmd", required=True) - calibrate = sub.add_parser( - "calibrate", help="calibrate protocols against positive/negative controls (DRF/BDS)") + calibrate = sub.add_parser("calibrate", help="calibrate protocols against positive/negative controls (DRF/BDS)") calibrate.add_argument("dataset", help="preprocessed .h5ad") - calibrate.add_argument("-p", "--protocols", action="append", default=[], - help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'") - calibrate.add_argument("--de-method", choices=DE_METHODS.names(), default="t-test", - help="DE backend for EVERY DE-dependent unit in the run: the interpolated " - "positive control, the top_k/degs spaces, the de_* protocols, and the WMSE weights") + calibrate.add_argument( + "-p", + "--protocols", + action="append", + default=[], + help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'", + ) + calibrate.add_argument( + "--de-method", + choices=DE_METHODS.names(), + default="t-test", + help="DE backend for EVERY DE-dependent unit in the run: the interpolated " + "positive control, the top_k/degs spaces, the de_* protocols, and the WMSE weights", + ) calibrate.add_argument("--subsample", type=int, default=8192) calibrate.add_argument("--seed", type=int, default=42) calibrate.add_argument("--positive", default="auto") calibrate.add_argument("--negative", default="auto") calibrate.add_argument( - "--output", default="drf", + "--output", + default="drf", choices=[n for n, c in CALIBRATORS.items() if "prediction" not in c.requires], - help="how per-perturbation values are calibrated (drf/bds)") + help="how per-perturbation values are calibrated (drf/bds)", + ) calibrate.add_argument("--out-dir", default="results") calibrate.add_argument("--workers", type=int, default=0, help="threads (0 = auto)") calibrate.add_argument("--perturbation-key", default="perturbation") calibrate.add_argument("--control-label", default="control") - calibrate.add_argument("--min-cells", type=int, default=30, - help="skip perturbations with fewer cells") - calibrate.add_argument("--profile", action="store_true", - help="also write a per-protocol wall-clock timing table") + calibrate.add_argument("--min-cells", type=int, default=30, help="skip perturbations with fewer cells") + calibrate.add_argument("--profile", action="store_true", help="also write a per-protocol wall-clock timing table") calibrate.add_argument("--quiet", action="store_true") calibrate.set_defaults(func=cmd_calibrate) - score = sub.add_parser( - "score", help="score model predictions against ground truth (real cells), per protocol") + score = sub.add_parser("score", help="score model predictions against ground truth (real cells), per protocol") score.add_argument("dataset", help="preprocessed .h5ad — the ground truth (real cells)") score.add_argument("predictions", help="predicted .h5ad — same genes and perturbation labels") - score.add_argument("-p", "--protocols", action="append", default=[], - help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'") - score.add_argument("--de-method", choices=DE_METHODS.names(), default="t-test", - help="DE backend for every DE-dependent unit (the top_k/degs spaces, the " - "de_* protocols, and the WMSE weights)") - score.add_argument("--subsample", type=int, default=8192, - help="cells in the all-perturbed reference sample (the ground truth itself is never subsampled)") + score.add_argument( + "-p", + "--protocols", + action="append", + default=[], + help="comma-separated names, a group (pseudobulk|distributional|de), or 'all'", + ) + score.add_argument( + "--de-method", + choices=DE_METHODS.names(), + default="t-test", + help="DE backend for every DE-dependent unit (the top_k/degs spaces, the de_* protocols, and the WMSE weights)", + ) + score.add_argument( + "--subsample", + type=int, + default=8192, + help="cells in the all-perturbed reference sample (the ground truth itself is never subsampled)", + ) score.add_argument("--seed", type=int, default=42) score.add_argument("--out-dir", default="results") score.add_argument("--workers", type=int, default=0, help="threads (0 = auto)") score.add_argument("--perturbation-key", default="perturbation") score.add_argument("--control-label", default="control") - score.add_argument("--min-cells", type=int, default=30, - help="skip perturbations with fewer cells") - score.add_argument("--profile", action="store_true", - help="also write a per-protocol wall-clock timing table") + score.add_argument("--min-cells", type=int, default=30, help="skip perturbations with fewer cells") + score.add_argument("--profile", action="store_true", help="also write a per-protocol wall-clock timing table") score.add_argument("--quiet", action="store_true") score.set_defaults(func=cmd_score) de = sub.add_parser("de", help="write per-gene DE (statistic + adj p) per method to HDF5") de.add_argument("dataset", help="preprocessed .h5ad") - de.add_argument("--methods", default="t-test,MWU", - help="comma-separated DE methods to compute (GT first-half vs all-perturbed)") + de.add_argument( + "--methods", default="t-test,MWU", help="comma-separated DE methods to compute (GT first-half vs all-perturbed)" + ) de.add_argument("--subsample", type=int, default=8192) de.add_argument("--seed", type=int, default=42) de.add_argument("--out-dir", default="results") diff --git a/src/scperteval/context.py b/src/scperteval/context.py index 976f332..40ba80d 100644 --- a/src/scperteval/context.py +++ b/src/scperteval/context.py @@ -1,5 +1,7 @@ """The per-run engine: lazily builds and caches the shared building blocks, and -turns a (perturbation, source) into the exact view a protocol consumes.""" +turns a (perturbation, source) into the exact view a protocol consumes. +""" + from __future__ import annotations import threading @@ -24,7 +26,7 @@ class Context: def __init__(self, dataset: Dataset, cfg: RunConfig): self.ds = dataset self.cfg = cfg - self.predictions = None # a PredictionSet in prediction-scoring mode, else None + self.predictions = None # a PredictionSet in prediction-scoring mode, else None self._local = threading.local() # Reentrant: several lazy initialisers (e.g. _ensure_ref_sums, ref_projection) # call reference() while already holding this lock, which a plain Lock would @@ -54,7 +56,8 @@ def current_pert(self, value): def warm(self, protocols): """Precompute shared singletons before the parallel loop so per-perturbation - threads only ever write per-perturbation cache keys.""" + threads only ever write per-perturbation cache keys. + """ self.control_mean() if any(p.representation in ("population", "de") for p in protocols): self.reference() @@ -63,8 +66,9 @@ def warm(self, protocols): self._moments("control", None) if any(p.space == "pca50" for p in protocols): self.pca() - for space in {p.space for p in protocols - if p.representation == "population" and SPACES.meta(p.space).get("global_space")}: + for space in { + p.space for p in protocols if p.representation == "population" and SPACES.meta(p.space).get("global_space") + }: self.ref_projection(space) def view(self, pert: str, source: str, p: Protocol): @@ -94,7 +98,8 @@ def centroid(self, pert, source, centering): def _de_view(self, pert, source, p): """GT -> truth labels (its DEResult); a candidate -> its |score| ranking. The negative candidate is tested against ``neg_reference`` (e.g. control) - rather than ``reference`` (the all-perturbed sample), the hybrid DE setup.""" + rather than ``reference`` (the all-perturbed sample), the hybrid DE setup. + """ if source == self.cfg.truth: return self.de(pert, self.cfg.truth, p.reference) reference = p.neg_reference if (source == p.negative and p.neg_reference) else p.reference @@ -102,16 +107,15 @@ def _de_view(self, pert, source, p): def de(self, pert, source, reference="all_perturbed"): """DE for one (source vs reference) comparison; the reference moments are - leave-one-out, so a perturbation is never compared against a sample of itself.""" + leave-one-out, so a perturbation is never compared against a sample of itself. + """ method = self.cfg.de_method key = (self._mom_key(source, pert), self._mom_key(reference, pert), method) if key not in self._de: if method == "t-test": - self._de[key] = ttest_from_moments(*self._moments(source, pert), - *self._moments(reference, pert)) + self._de[key] = ttest_from_moments(*self._moments(source, pert), *self._moments(reference, pert)) else: - self._de[key] = DE_METHODS[method](self._de_cells(source, pert), - self._de_cells(reference, pert)) + self._de[key] = DE_METHODS[method](self._de_cells(source, pert), self._de_cells(reference, pert)) return self._de[key] def _moments(self, source, pert): @@ -147,7 +151,8 @@ def wmse_weights(self, pert): def reference(self) -> Reference: """The all-perturbed sample (subsampled + densified once), with each cell's - perturbation recorded so it can be served leave-one-out.""" + perturbation recorded so it can be served leave-one-out. + """ if self._reference is None: with self._init_lock: if self._reference is None: @@ -158,7 +163,8 @@ def reference(self) -> Reference: def _reference_population(self, space, pert): """The reference in a feature space with the target perturbation removed: - project the whole sample (cached for global spaces) then drop its rows.""" + project the whole sample (cached for global spaces) then drop its rows. + """ ref = self.reference() if SPACES.meta(space).get("global_space"): proj = self.ref_projection(space) @@ -177,7 +183,8 @@ def ref_projection(self, space): def _ensure_ref_sums(self): """Cache the reference's column sums and sums-of-squares once, so leave-one-out - moments are an O(target cells) subtraction rather than a re-densify per perturbation.""" + moments are an O(target cells) subtraction rather than a re-densify per perturbation. + """ if self._ref_sums is None: with self._init_lock: if self._ref_sums is None: @@ -219,7 +226,8 @@ def pca(self, k=50): def _fit_pca(self, n_components): """Fit PCA on (nearly) all cells; the subsample cap is for the O(n^2) - distance populations, not the PCA basis, which needs many cells to be stable.""" + distance populations, not the PCA basis, which needs many cells to be stable. + """ from sklearn.decomposition import PCA n = self.ds.adata.n_obs diff --git a/src/scperteval/predictions.py b/src/scperteval/predictions.py index c239e04..35a7da2 100644 --- a/src/scperteval/predictions.py +++ b/src/scperteval/predictions.py @@ -5,6 +5,7 @@ any order) and the same perturbation column; columns are reordered to the dataset's gene order so every metric's positional ``gt - prediction`` comparison lines up. """ + from __future__ import annotations import anndata as ad @@ -18,7 +19,8 @@ def _align_genes(pred_genes: np.ndarray, ds_genes: np.ndarray) -> np.ndarray: """Indices that reorder the prediction's genes into the dataset's gene order. Errors (naming what's wrong) unless the two gene sets are identical -- metrics compare - gene vectors positionally, so a mismatch would silently compare the wrong genes.""" + gene vectors positionally, so a mismatch would silently compare the wrong genes. + """ pred_set, ds_set = set(map(str, pred_genes)), set(map(str, ds_genes)) missing = [g for g in map(str, ds_genes) if g not in pred_set] extra = [g for g in map(str, pred_genes) if g not in ds_set] @@ -47,7 +49,7 @@ def __init__(self, adata, ds: Dataset, cfg: RunConfig): self.pert = np.asarray(adata.obs[cfg.perturbation_key]).astype(str) @classmethod - def load(cls, path: str, ds: Dataset, cfg: RunConfig) -> "PredictionSet": + def load(cls, path: str, ds: Dataset, cfg: RunConfig) -> PredictionSet: return cls(ad.read_h5ad(path), ds, cfg) def cells(self, pert: str) -> np.ndarray: diff --git a/src/scperteval/runner.py b/src/scperteval/runner.py index b17de77..f24ff31 100644 --- a/src/scperteval/runner.py +++ b/src/scperteval/runner.py @@ -1,4 +1,5 @@ """Runs one protocol over every perturbation and applies the chosen calibrator.""" + from __future__ import annotations import os @@ -18,7 +19,8 @@ def n_workers(cfg) -> int: def resolve_roles(p: Protocol, cfg) -> dict: """Map each candidate role a calibrator may require to a source name. ``positive`` / ``negative`` come from the protocol (or a CLI override); ``prediction`` is always the - model-prediction source (used by the ``score`` calibrator).""" + model-prediction source (used by the ``score`` calibrator). + """ return { "positive": cfg.positive if cfg.positive != "auto" else p.positive, "negative": cfg.negative if cfg.negative != "auto" else p.negative, @@ -42,15 +44,21 @@ def run_protocol(p: Protocol, ctx, calibrator: Calibrator): def _finalize(p, calibrator, perts, raws_list): """Per-perturbation rows + the aggregate, from each perturbation's raw control values.""" per_pert = [calibrator.per_pert(raws, p) for raws in raws_list] - rows = [{"protocol": p.name, "perturbation": pert, - **{f"raw_{role}": raws[role] for role in raws}, - calibrator.name: value} - for pert, raws, value in zip(perts, raws_list, per_pert)] + rows = [ + { + "protocol": p.name, + "perturbation": pert, + **{f"raw_{role}": raws[role] for role in raws}, + calibrator.name: value, + } + for pert, raws, value in zip(perts, raws_list, per_pert) + ] return calibrator.aggregate(np.asarray(per_pert, dtype=float)), rows def _run_per_perturbation(p: Protocol, ctx, calibrator: Calibrator, needed: dict): """Score one perturbation at a time (across a thread pool), gt vs each control.""" + def work(pert): ctx.current_pert = pert gt = ctx.view(pert, ctx.cfg.truth, p) @@ -93,7 +101,8 @@ def collect(source): def compute_de_export(ctx, methods): """{method: (statistic, pvalue_adj)} matrices (perturbations x genes) for each - method's GT(first-half)-vs-all-perturbed differential expression.""" + method's GT(first-half)-vs-all-perturbed differential expression. + """ out = {} for method in methods: ctx.cfg.de_method = method diff --git a/src/scperteval/sources.py b/src/scperteval/sources.py index 93ffd70..6bb0f6d 100644 --- a/src/scperteval/sources.py +++ b/src/scperteval/sources.py @@ -5,6 +5,7 @@ check and how the context turns a source into a view. ``description`` is shown by ``scperteval list sources``. """ + from __future__ import annotations import numpy as np @@ -15,61 +16,85 @@ SOURCES = Registry("source") -@SOURCES.register("gt_half", provides="cells", - description="ground truth — the first half of a perturbation's cells (calibration truth)") +@SOURCES.register( + "gt_half", + provides="cells", + description="ground truth — the first half of a perturbation's cells (calibration truth)", +) def src_gt_half(ctx, pert): return ctx.ds.cells(pert, half="first") -@SOURCES.register("gt_all_cells", provides="cells", - description="ground truth — all of a perturbation's real cells (prediction-scoring truth)") +@SOURCES.register( + "gt_all_cells", + provides="cells", + description="ground truth — all of a perturbation's real cells (prediction-scoring truth)", +) def src_gt_all_cells(ctx, pert): return ctx.ds.cells(pert) -@SOURCES.register("prediction", provides="cells", - description="model-predicted cells for the perturbation, from the --predictions h5ad") +@SOURCES.register( + "prediction", + provides="cells", + description="model-predicted cells for the perturbation, from the --predictions h5ad", +) def src_prediction(ctx, pert): return ctx.predictions.cells(pert) -@SOURCES.register("tech_dup", provides="cells", - description="technical duplicate — the held-out second half (single-cell positive control)") +@SOURCES.register( + "tech_dup", + provides="cells", + description="technical duplicate — the held-out second half (single-cell positive control)", +) def src_tech_dup(ctx, pert): return ctx.ds.cells(pert, half="second") -@SOURCES.register("control", provides="cells", - description="non-targeting control cells") +@SOURCES.register("control", provides="cells", description="non-targeting control cells") def src_control(ctx, pert): return ctx.ds.control_cells(ctx.cfg.subsample) -@SOURCES.register("all_perturbed", provides="cells", - description="all-perturbed reference sample, leave-one-out (single-cell negative control)") +@SOURCES.register( + "all_perturbed", + provides="cells", + description="all-perturbed reference sample, leave-one-out (single-cell negative control)", +) def src_all_perturbed(ctx, pert): return ctx.reference().subset(pert) -@SOURCES.register("all_perturbed_mean", provides="centroid", - description="all-perturbed mean, excluding the target — leave-one-out " - "(pseudobulk sibling of all_perturbed; pseudobulk negative control)") +@SOURCES.register( + "all_perturbed_mean", + provides="centroid", + description="all-perturbed mean, excluding the target — leave-one-out " + "(pseudobulk sibling of all_perturbed; pseudobulk negative control)", +) def src_all_perturbed_mean(ctx, pert): return ctx.ds.allpert_mean_except(pert) -@SOURCES.register("global_mean", provides="centroid", - description="mean of all perturbations — shared baseline for the ranking protocols") +@SOURCES.register( + "global_mean", + provides="centroid", + description="mean of all perturbations — shared baseline for the ranking protocols", +) def src_global_mean(ctx, pert): return ctx.ds.allpert_mean() -@SOURCES.register("interpolated", provides="centroid", - description="interpolated duplicate — DE-weighted blend of the held-out half and " - "the dataset mean (pseudobulk positive control)") +@SOURCES.register( + "interpolated", + provides="centroid", + description="interpolated duplicate — DE-weighted blend of the held-out half and " + "the dataset mean (pseudobulk positive control)", +) def src_interpolated(ctx, pert): - """alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward - the held-out replicate where the gene is significant, else toward the all-perturbed mean.""" + """Alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward + the held-out replicate where the gene is significant, else toward the all-perturbed mean. + """ tech = np.asarray(to_dense(ctx.ds.cells(pert, half="second"))).mean(0) alpha = np.nan_to_num(1.0 - ctx.de(pert, "tech_dup", "control").pvalue_adj, nan=0.0) return alpha * tech + (1.0 - alpha) * ctx.ds.allpert_mean_except(pert) diff --git a/src/scperteval/types.py b/src/scperteval/types.py index ecb7bae..176299e 100644 --- a/src/scperteval/types.py +++ b/src/scperteval/types.py @@ -1,9 +1,10 @@ """Core dataclasses shared across the package.""" + from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass, field, replace from functools import partial -from typing import Callable, Optional import numpy as np @@ -20,7 +21,7 @@ class RunConfig: positive: str = "auto" negative: str = "auto" truth: str = "gt_half" - predictions: Optional[str] = None + predictions: str | None = None output: str = "drf" out_dir: str = "results" workers: int = 0 @@ -50,7 +51,7 @@ class Param: name: str cast: Callable default: float - space: Optional[Callable] = None + space: Callable | None = None @dataclass(frozen=True) @@ -93,21 +94,21 @@ class Protocol: representation: str scope: str = "perturbation" space: str = "full" - centering: Optional[str] = None + centering: str | None = None reference: str = "all_perturbed" - neg_reference: Optional[str] = None + neg_reference: str | None = None better: str = "higher" perfect: float = 1.0 positive: str = "auto" negative: str = "auto" group: str = "" - param: Optional[Param] = None + param: Param | None = None @property def parameterised(self) -> bool: return self.param is not None - def resolve(self, value) -> "Protocol": + def resolve(self, value) -> Protocol: """Concrete protocol for a tunable one at ``value`` (sets the space or metric arg).""" suffix = f"{value:g}" if isinstance(value, float) else str(value) name = f"{self.name}={suffix}" From 8223db26a99c81b8714f7f5fdc836b0396827ae7 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 18:42:20 -0700 Subject: [PATCH 10/15] docs: reconcile to calibrate/score API; add scoring page --- README.md | 15 ++++++--- docs/api.md | 18 +++++++++- docs/index.md | 16 +++++---- docs/user-guide/index.md | 5 ++- docs/user-guide/protocols.md | 33 ++++++++++++------- docs/user-guide/scoring.md | 48 +++++++++++++++++++++++++++ docs/user-guide/usage.md | 64 ++++++++++++++++++++++++++++-------- 7 files changed, 162 insertions(+), 37 deletions(-) create mode 100644 docs/user-guide/scoring.md diff --git a/README.md b/README.md index 0bd9523..46eda33 100644 --- a/README.md +++ b/README.md @@ -1,9 +1,11 @@ # scPertEval — Evaluation Protocols for Perturbation Sequencing scPertEval is a command-line tool for **experimenting with and sharing reference implementations of -evaluation protocols** in single-cell perturbation studies. It calibrates each protocol -against empirical positive and negative controls per perturbation, outputting the -**Dynamic Range Fraction (DRF)** and the **Bound Discrimination Score (BDS)**. +evaluation protocols** in single-cell perturbation studies. The same catalog of protocols backs +three commands: **`score`** (score a model's predictions against ground truth), **`calibrate`** +(calibrate a protocol against empirical positive/negative controls per perturbation, reporting the +**Dynamic Range Fraction (DRF)** and **Bound Discrimination Score (BDS)**), and **`de`** (export +per-gene differential expression). Our accompanying publication: TODO_LINK_HERE @@ -24,7 +26,12 @@ pip install "scperteval @ git+https://github.com/Virtual-Cell-Research-Community ## Quick start ```bash -scperteval run data/wessels23.h5ad -p all --de-method t-test +# calibrate protocols against built-in controls (DRF/BDS) +scperteval calibrate data/wessels23.h5ad -p all --de-method t-test + +# score a model's predictions against ground truth +scperteval score data/wessels23.h5ad predictions.h5ad -p all + scperteval list protocols # also: de-methods | spaces | sources | calibrators ``` diff --git a/docs/api.md b/docs/api.md index ff5acac..d1f4cc4 100644 --- a/docs/api.md +++ b/docs/api.md @@ -70,7 +70,8 @@ :no-members: ``` -`scperteval.calibrators.CALIBRATORS` — `{name: Calibrator}` dict of built-in calibrators (`drf`, `bds`). +`scperteval.calibrators.CALIBRATORS` — `{name: Calibrator}` dict of built-in calibrators (`drf`, +`bds`, and `score` for the prediction-scoring mode). Add entries here to register a new calibrator; see [Add a calibrator](user-guide/building-blocks.md#add-a-calibrator). ## Building blocks @@ -127,6 +128,21 @@ Add entries here to register a new calibrator; see [Add a calibrator](user-guide `scperteval.sources.SOURCES` — registry of all control/reference sources. Add entries here to register a new source; see [Add a control source](user-guide/building-blocks.md#add-a-control-source). +### Predictions + +```{eval-rst} +.. module:: scperteval.predictions +.. currentmodule:: scperteval.predictions + +.. autosummary:: + :toctree: generated + + PredictionSet +``` + +`scperteval.predictions.PredictionSet` — model-predicted cells loaded from a `.h5ad` and +gene-aligned to the dataset, used by the `score` command. + ## Context ```{eval-rst} diff --git a/docs/index.md b/docs/index.md index 4686f1f..4a2ea58 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,17 +12,21 @@ protocols across the field and ultimately choose the right approach for a given problem space. scPertEval renders each protocol as a short, readable building block to run, read, reuse, -and contribute back — a place for collaboration and alignment in the field. Run the tool by -specifying a dataset, one or more protocols, and a method of differential expression; the -tool outputs calibration data: the **Dynamic Range Fraction (DRF)** and the **Bound -Discrimination Score (BDS)** — quantifying how well the protocol separates real perturbation -signal from an uninformative baseline (see [Calibration](user-guide/calibration.md)). +and contribute back — a place for collaboration and alignment in the field. The same catalog +of protocols backs three commands: + +- **`score`** — score a model's predictions against ground truth, one metric value per + perturbation (see [Scoring predictions](user-guide/scoring.md)). +- **`calibrate`** — calibrate a protocol against built-in positive/negative controls, reporting + the **Dynamic Range Fraction (DRF)** and **Bound Discrimination Score (BDS)** — how well it + separates real signal from an uninformative baseline (see [Calibration](user-guide/calibration.md)). +- **`de`** — export per-gene differential expression to HDF5. ## Quick start ```bash pip install scperteval -scperteval run data/wessels23.h5ad -p all --de-method t-test +scperteval calibrate data/wessels23.h5ad -p all --de-method t-test ``` ::::{grid} 1 2 3 3 diff --git a/docs/user-guide/index.md b/docs/user-guide/index.md index a34b3f3..05fb8a6 100644 --- a/docs/user-guide/index.md +++ b/docs/user-guide/index.md @@ -1,10 +1,13 @@ # User guide -Start with [Calibration](calibration) to understand what scPertEval measures before diving into usage. +scPertEval runs in two modes: [Scoring predictions](scoring) compares a model's output to +ground truth, while [Calibration](calibration) measures whether a protocol can tell real signal +from an uninformative baseline. Start with whichever matches your goal, then see [Usage](usage). ```{toctree} :maxdepth: 1 +scoring calibration usage protocols diff --git a/docs/user-guide/protocols.md b/docs/user-guide/protocols.md index 5aebf80..a8a0b94 100644 --- a/docs/user-guide/protocols.md +++ b/docs/user-guide/protocols.md @@ -5,8 +5,9 @@ Two files define each protocol: - **[`src/scperteval/protocols/metrics.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/metrics.py)** — the metric, as a - pure function of the ground truth and a `prediction` (whichever control is being scored, - positive or negative). e.g. `mse`, `mmd`, `de_auprc`: + pure function of the ground truth and a `prediction` (the candidate being scored — a positive + or negative control under `calibrate`, or your model's output under `score`). e.g. `mse`, + `mmd`, `de_auprc`: ```python def mse(gt, prediction, ctx): @@ -43,9 +44,10 @@ Here is a complete new protocol: mean absolute error on the standard pseudobulk ``` Every metric function has this signature. `gt` is one perturbation's ground-truth - profile; `prediction` is a control being compared against it (scPertEval calls the function - once for the positive control and once for the negative). `ctx` is the dataset context, - needed by only a few metrics — ignore it otherwise. Return a single number. + profile; `prediction` is the candidate being compared against it (under `calibrate`, scPertEval + calls the function once for the positive control and once for the negative; under `score`, once + with your model's prediction). `ctx` is the dataset context, needed by only a few metrics — + ignore it otherwise. Return a single number. 2. Add a row to [`src/scperteval/protocols/table.py`](https://github.com/Virtual-Cell-Research-Community/scPertEval/blob/main/src/scperteval/protocols/table.py): @@ -55,7 +57,7 @@ Here is a complete new protocol: mean absolute error on the standard pseudobulk better="lower", perfect=0.0) ``` -Run it with `scperteval run data.h5ad -p mae`. That is the whole protocol: MAE between each +Run it with `scperteval calibrate data.h5ad -p mae`. That is the whole protocol: MAE between each perturbation's pseudobulk profile and its positive and negative controls, scored as lower-is-better toward a perfect of 0. @@ -82,7 +84,7 @@ That row is the spec; parameters include: |---|---| | `centroid` | a 1-D pseudobulk vector (one value per gene) | | `population` | a `(cells × genes)` matrix | -| `de` | a `DEResult` (for `gt`) / per-gene `\|score\|` ranking (for a prediction) | +| `de` | a `DEResult` (for the ground truth) / per-gene `\|score\|` ranking (for a prediction) | **`scope`** is the independent companion axis — *how many* perturbations the metric sees at once: @@ -147,13 +149,18 @@ all_perturbed (cells) — all-perturbed reference sample, leave-one-out (single all_perturbed_mean (centroid) — all-perturbed mean, excluding the target — leave-one-out (pseudobulk sibling of all_perturbed; pseudobulk negative control) control (cells) — non-targeting control cells global_mean (centroid) — mean of all perturbations — shared baseline for the ranking protocols -gt (cells) — ground truth — the first half of a perturbation's cells +gt_all_cells (cells) — ground truth — all of a perturbation's real cells (prediction-scoring truth) +gt_half (cells) — ground truth — the first half of a perturbation's cells (calibration truth) interpolated (centroid) — interpolated duplicate — DE-weighted blend of the held-out half and the dataset mean (pseudobulk positive control) +prediction (cells) — model-predicted cells for the perturbation, from the --predictions h5ad tech_dup (cells) — technical duplicate — the held-out second half (single-cell positive control) ``` Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative=` (or -`--positive`/`--negative`). To add another, see [Add a control source](building-blocks.md#add-a-control-source). +`--positive`/`--negative`). The truth source is chosen by the command, not by a protocol: +`calibrate` uses `gt_half` (holding the other half out as the positive control), while `score` +uses `gt_all_cells` and compares it to `prediction`. To add another, see +[Add a control source](building-blocks.md#add-a-control-source). **Calibrators** (the `--output` choice) @@ -161,9 +168,11 @@ Each `provides` cells or a pseudobulk `centroid`. Use via `positive=`/`negative= $ scperteval list calibrators drf — Dynamic Range Fraction — mean/median over perturbations (Miller et al. 2025) bds — Bound Discrimination Score — fraction of perturbations the positive control wins (SBB 2026) +score — raw metric of a prediction vs ground truth — mean/median over perturbations (prediction-scoring mode) ``` -Chosen with `--output`. To add another, see [Add a calibrator](building-blocks.md#add-a-calibrator). +`drf`/`bds` are chosen with `calibrate --output`; `score` is selected automatically by the +`score` command. To add another, see [Add a calibrator](building-blocks.md#add-a-calibrator). ### More examples @@ -187,7 +196,7 @@ top-50 DEGs: Protocol("mae_top50", M.mae, representation="centroid", space="top_50", **_PB, **_LOWER) ``` -**Expose the space as a knob (parameterised).** To make `k` adjustable per run, add a +**Expose the space as a knob (parameterised).** To make `k` adjustable per invocation, add a `param` to the same `Protocol(...)` row — nothing else changes. The row's name carries the parameter, and the value is supplied on the CLI: @@ -195,7 +204,7 @@ parameter, and the value is supplied on the CLI: Protocol("mae_top_k", M.mae, representation="centroid", param=top_k, **_PB, **_LOWER) ``` -Then `scperteval run data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The +Then `scperteval calibrate data.h5ad -p mae_top_k=30` (or `mae_top_k` for the default `k=50`). The families are `top_k` (top-k DEGs), `pca_k` (k PCs), and `degs_padj` (DEGs at adjusted p < padj) for the space, and `overlap_k` to feed an integer straight to the metric. diff --git a/docs/user-guide/scoring.md b/docs/user-guide/scoring.md new file mode 100644 index 0000000..f0557ba --- /dev/null +++ b/docs/user-guide/scoring.md @@ -0,0 +1,48 @@ +# Scoring predictions + +`scperteval score dataset.h5ad predictions.h5ad` is the conventional evaluation: each protocol's +metric is applied to your **predicted** cells against the **real** cells, one score per +perturbation. It runs the *same* protocol catalog as [`calibrate`](calibration.md); only two +pieces differ. + +## Inputs + +- **ground truth** — *all* of a perturbation's real cells (the `gt_all_cells` source). Unlike + calibration, no half is held out and no positive/negative controls are built — the ground truth + is the whole real population. +- **prediction** — the matching cells from your `predictions.h5ad` (the `prediction` source). + The prediction file must contain the dataset's exact gene set (any order — columns are + reordered by name so the comparison lines up gene-for-gene) and the same perturbation labels. + A gene-set mismatch, or a perturbation present in the dataset but absent from the predictions, + raises an error naming exactly what's wrong. + +## Output + +The `score` calibrator reports each protocol's raw metric value per perturbation and its +mean/median across perturbations, written to `____score.csv`. Higher- vs +lower-is-better follows each protocol's `better` field, exactly as in calibration. + +```bash +scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc,unbiased_mmd_median_pca_k=20 +``` + +| protocol | representation | perfect prediction | degraded prediction | +|---|---|---|---| +| `pearson` | centroid | 1.000 | 0.993 | +| `mse` | centroid | 0.000 | 0.004 | +| `de_auprc` | de | 1.000 | 0.297 | +| `unbiased_mmd_median_pca_k=20` | population | ≈0 | 0.199 | + +(An exact replica of the real cells scores optimally; a prediction degraded toward the control +mean scores worse on every representation.) + +## How it relates to calibration + +Architecturally this reuses everything — the per-perturbation loop, every metric, representation, +and feature space are shared with `calibrate`. The only differences are the **truth source** +(`gt_all_cells` instead of the held-out `gt_half`) and the **calibrator** (`score`, which needs +only the prediction, instead of `drf`/`bds`, which need both controls). The DE-derived feature +spaces (`top_k`, `degs`) and the WMSE weights are computed from this same all-cells ground truth. + +Use `score` to measure how good a model's predictions are; use [`calibrate`](calibration.md) to +decide whether a given protocol is trustworthy enough to report those scores in the first place. diff --git a/docs/user-guide/usage.md b/docs/user-guide/usage.md index 3134453..fe20403 100644 --- a/docs/user-guide/usage.md +++ b/docs/user-guide/usage.md @@ -24,19 +24,30 @@ No gcloud account is needed — each file is also reachable over plain HTTPS at ## Run it +The same protocol catalog backs three commands: **`calibrate`** (calibrate a protocol against +built-in controls → DRF/BDS, see [Calibration](calibration.md)), **`score`** (score a model's +predictions against ground truth, see [Scoring predictions](scoring.md)), and **`de`** (export +per-gene differential expression). + +## Calibrate it + ```bash # protocols by name — including parameterised ones (set k / padj per protocol) -scperteval run data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test +scperteval calibrate data/wessels23.h5ad -p pearson_ctrl,unbiased_mmd_median_pca_k=20,de_overlap_k=10 --de-method t-test # a parameterised protocol with no value uses its default (k=50, padj=0.05) -scperteval run data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU +scperteval calibrate data/wessels23.h5ad -p unbiased_mmd_median_top_k --de-method MWU # a whole group, or everything (parameterised protocols use their defaults) -scperteval run data/wessels23.h5ad -p distributional --de-method MWU -scperteval run data/wessels23.h5ad -p all --de-method t-test +scperteval calibrate data/wessels23.h5ad -p distributional --de-method MWU +scperteval calibrate data/wessels23.h5ad -p all --de-method t-test # DRF calibration only (compute DRF only; exclude BDS) -scperteval run data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf +scperteval calibrate data/wessels23.h5ad -p pearson_ctrl --de-method t-test --output drf + +# SCORE predictions against ground truth — predicted cells vs real cells, per protocol. +# predictions.h5ad must have the same genes and perturbation labels as the dataset. +scperteval score data/wessels23.h5ad predictions.h5ad -p pearson,mse,de_auprc --de-method t-test # DE only — writes per-gene statistic + adjusted p to HDF5 (no protocol calibration) # Provided as a convenience, since DE methods are tightly coupled with some evaluation protocols @@ -46,20 +57,21 @@ scperteval de data/wessels23.h5ad --methods MWU scperteval list protocols # also: de-methods | spaces | sources | calibrators ``` -Each run prints a summary table and writes a per-perturbation CSV -`____drf.csv` (the raw control values and the calibrated score for -every perturbation). `--profile` adds a per-protocol wall-clock timing CSV. +Each command prints a summary table and writes a per-perturbation CSV named +`____.csv`: `calibrate` writes the raw control values and the +calibrated DRF/BDS per perturbation (`…__drf.csv` / `…__bds.csv`); `score` writes the raw metric +value per perturbation (`…__score.csv`). `--profile` adds a per-protocol wall-clock timing CSV. **DE backends** (`scperteval list de-methods`): `t-test` (default, Welch's, moment-based), `MWU` (Cliff's δ via illico), and `t-test_overestim_var` (scanpy's conservative-variance variant — the reference variance is scaled by the target's cell count). Select one with -`--de-method` for a `run`, or list several with `--methods` for a `de` export. The overestim -variant is a selectable backend for new protocols; no current protocol uses it. +`--de-method` for a `calibrate`/`score`, or list several with `--methods` for a `de` export. The +overestim variant is a selectable backend for new protocols; no current protocol uses it. -
scperteval run --help +
scperteval calibrate --help ```text -usage: scperteval run [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] +usage: scperteval calibrate [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] [--subsample SUBSAMPLE] [--seed SEED] [--positive POSITIVE] [--negative NEGATIVE] [--output {drf,bds}] [--out-dir OUT_DIR] [--workers WORKERS] [--perturbation-key PERTURBATION_KEY] @@ -81,6 +93,27 @@ usage: scperteval run [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overes
+
scperteval score --help + +```text +usage: scperteval score [-h] [-p PROTOCOLS] [--de-method {MWU,t-test,t-test_overestim_var}] + [--subsample SUBSAMPLE] [--seed SEED] [--out-dir OUT_DIR] [--workers WORKERS] + [--perturbation-key PERTURBATION_KEY] [--control-label CONTROL_LABEL] + [--min-cells MIN_CELLS] [--profile] [--quiet] + dataset predictions + + dataset preprocessed .h5ad — the ground truth (real cells) + predictions predicted .h5ad — same genes and perturbation labels as the dataset + -p, --protocols comma-separated names, a group, or 'all' + --de-method DE backend for the de_* protocols, the top_k/degs spaces, and WMSE weights + --subsample cells in the all-perturbed reference (the ground truth is never subsampled) +``` + +Unlike `calibrate`, there are no `--positive`/`--negative`/`--output` options: the candidate is +always your prediction and the output is always the raw `score`. + +
+ ## Use it from Python Install with `pip install scperteval` (or, from this repo, @@ -90,7 +123,12 @@ The simplest path mirrors the CLI — call it via subprocess, exactly as the fig ```python import subprocess, sys -subprocess.run([sys.executable, "-m", "scperteval", "run", "data/wessels23.h5ad", +subprocess.run([sys.executable, "-m", "scperteval", "calibrate", "data/wessels23.h5ad", "-p", "all", "--de-method", "t-test", "--out-dir", "results"], check=True) # -> results/wessels23____drf.csv (raw control values + calibrated DRF per perturbation) + +# score predictions against ground truth instead: +subprocess.run([sys.executable, "-m", "scperteval", "score", "data/wessels23.h5ad", + "predictions.h5ad", "-p", "all", "--out-dir", "results"], check=True) +# -> results/wessels23____score.csv (raw metric value per perturbation) ``` From 2d368eb084356af14f770cf14481c38fdc822672 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 19:38:27 -0700 Subject: [PATCH 11/15] chore(lint): make ruff, mypy, and pyright pass clean Add the missing public docstrings (D102/D103), fix D205 summary lines and D301 r-strings, and resolve mypy/pyright narrowing on Param|None and the None-initialised Context caches. --- src/scperteval/blocks/de.py | 6 +-- src/scperteval/blocks/spaces.py | 7 +++- src/scperteval/calibrators.py | 7 ++-- src/scperteval/cli.py | 21 +++++++--- src/scperteval/context.py | 65 +++++++++++++++++++---------- src/scperteval/dataset.py | 15 +++++-- src/scperteval/predictions.py | 1 + src/scperteval/protocols/metrics.py | 27 ++++++------ src/scperteval/reference.py | 4 +- src/scperteval/registry.py | 1 + src/scperteval/runner.py | 22 ++++++---- src/scperteval/sources.py | 12 +++++- src/scperteval/types.py | 7 +++- 13 files changed, 131 insertions(+), 64 deletions(-) diff --git a/src/scperteval/blocks/de.py b/src/scperteval/blocks/de.py index 204fc80..3879315 100644 --- a/src/scperteval/blocks/de.py +++ b/src/scperteval/blocks/de.py @@ -33,7 +33,7 @@ def de_my_test(target, reference): def moments(X): - """Per-gene mean, sample variance, and cell count for a cell matrix. + r"""Per-gene mean, sample variance, and cell count for a cell matrix. Sparse- and dense-aware; uses :math:`\\text{Var}(X) = E[X^2] - E[X]^2` with Bessel's correction. @@ -138,10 +138,10 @@ def de_ttest(target, reference) -> DEResult: @DE_METHODS.register( "t-test_overestim_var", description="scanpy's conservative t-test variant; reference variance scaled by the " - "target's cell count (selectable backend; not used by any current protocol)", + "target's cell count (selectable backend; not used by any current protocol)", ) def de_ttest_overestim(target, reference) -> DEResult: - """scanpy ``rank_genes_groups(method='t-test_overestim_var')``. + """Scanpy ``rank_genes_groups(method='t-test_overestim_var')``. Identical to Welch's t-test except the reference group's cell count is replaced by the target's, which inflates the reference standard-error term ("overestimating" its variance diff --git a/src/scperteval/blocks/spaces.py b/src/scperteval/blocks/spaces.py index 26e34f9..1f05059 100644 --- a/src/scperteval/blocks/spaces.py +++ b/src/scperteval/blocks/spaces.py @@ -20,6 +20,7 @@ @SPACES.register("full", global_space=True, description="all genes, no transform") def space_full(X, ctx, pert): + """Identity space: all genes, densified, no transform.""" return to_dense(X) @@ -32,7 +33,11 @@ def register_de_space(name, field, top=None, threshold=None, description=""): def space(X, ctx, pert): values = _field(ctx.de(pert, ctx.cfg.truth), field) - keep = np.argsort(-np.abs(values))[:top] if top is not None else np.where(threshold(values))[0] + if top is not None: + keep = np.argsort(-np.abs(values))[:top] + else: + assert threshold is not None # register_de_space takes exactly one of top/threshold + keep = np.where(threshold(values))[0] return to_dense(X[:, keep]) SPACES.add(name, space, description=description) diff --git a/src/scperteval/calibrators.py b/src/scperteval/calibrators.py index a1b9a45..c7c9298 100644 --- a/src/scperteval/calibrators.py +++ b/src/scperteval/calibrators.py @@ -1,6 +1,7 @@ -"""Calibrators turn the raw metric values measured on each control into a final -per-metric score. Each declares the control roles it needs, a per-perturbation -combine, and a cross-perturbation aggregate. +"""Calibrators turn raw control metric values into a final per-metric score. + +Each declares the control roles it needs, a per-perturbation combine, and a +cross-perturbation aggregate. """ from __future__ import annotations diff --git a/src/scperteval/cli.py b/src/scperteval/cli.py index 4814654..86909e3 100644 --- a/src/scperteval/cli.py +++ b/src/scperteval/cli.py @@ -20,7 +20,7 @@ def _concrete(p: Protocol) -> Protocol: """A tunable protocol at its default value; a fixed protocol unchanged.""" - return p.resolve(p.param.default) if p.parameterised else p + return p.resolve(p.param.default) if p.parameterised else p # type: ignore[union-attr] def _resolve_token(token: str) -> list[Protocol]: @@ -33,7 +33,7 @@ def _resolve_token(token: str) -> list[Protocol]: p = PROTOCOLS.get(name) if p is None or not p.parameterised: raise SystemExit(f"unknown tunable protocol {name!r}; try `scperteval list protocols`") - return [p.resolve(p.param.cast(value))] + return [p.resolve(p.param.cast(value))] # type: ignore[union-attr] p = PROTOCOLS.get(token) if p is None: raise SystemExit(f"unknown protocol {token!r}; try `scperteval list protocols`") @@ -41,6 +41,7 @@ def _resolve_token(token: str) -> list[Protocol]: def resolve_protocols(specs: list[str]) -> list[Protocol]: + """Resolve CLI protocol specs to a de-duplicated list of concrete protocols.""" out: list[Protocol] = [] for spec in specs: for token in spec.split(","): @@ -54,9 +55,10 @@ def resolve_protocols(specs: list[str]) -> list[Protocol]: def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: - """Run every protocol over the dataset, print the summary, and write the per-perturbation - CSV. Shared by ``calibrate`` and ``score`` (prediction vs ground truth); they differ only - in how ``ctx`` is built and which calibrator ``cfg.output`` selects. + """Run every protocol over the dataset, print the summary, and write the CSV. + + Shared by ``calibrate`` and ``score`` (prediction vs ground truth); they differ only in + how ``ctx`` is built and which calibrator ``cfg.output`` selects. """ calibrator = CALIBRATORS[cfg.output] ctx.warm(protocols) @@ -76,6 +78,7 @@ def _evaluate(cfg: RunConfig, protocols, ctx, quiet: bool) -> None: def cmd_calibrate(args) -> None: + """Run the ``calibrate`` command: score protocols against built-in controls (DRF/BDS).""" protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( dataset=args.dataset, @@ -98,6 +101,7 @@ def cmd_calibrate(args) -> None: def cmd_score(args) -> None: + """Run the ``score`` command: score predictions against ground truth, per protocol.""" protocols = resolve_protocols(args.protocols or ["all"]) cfg = RunConfig( dataset=args.dataset, @@ -115,6 +119,7 @@ def cmd_score(args) -> None: predictions=args.predictions, truth="gt_all_cells", ) + assert cfg.predictions is not None # required positional on the score subcommand ds = Dataset.load(cfg.dataset, cfg) ctx = Context(ds, cfg) ctx.predictions = PredictionSet.load(cfg.predictions, ds, cfg) @@ -122,6 +127,7 @@ def cmd_score(args) -> None: def cmd_de(args) -> None: + """Run the ``de`` command: export per-gene differential expression to HDF5.""" methods = [m.strip() for m in args.methods.split(",") if m.strip()] cfg = RunConfig( dataset=args.dataset, @@ -144,6 +150,8 @@ def cmd_de(args) -> None: def cmd_list(args) -> None: + """Run the ``list`` command: print the available building blocks of one category.""" + def reg(registry, fmt): return [fmt(n, registry.meta(n)) for n in registry.names()] @@ -163,10 +171,13 @@ def descr(p): lines = reg(SOURCES, lambda n, m: f"{n:14s} ({m.get('provides')}) — {m.get('description', '')}") elif args.what == "calibrators": lines = [f"{n:6s} — {c.description}" for n, c in CALIBRATORS.items()] + else: + raise AssertionError(f"unexpected list target: {args.what!r}") print("\n".join(lines)) def main(argv=None) -> None: + """Parse arguments and dispatch to the selected subcommand.""" parser = argparse.ArgumentParser(prog="scperteval", description=__doc__) sub = parser.add_subparsers(dest="cmd", required=True) diff --git a/src/scperteval/context.py b/src/scperteval/context.py index 40ba80d..6869f67 100644 --- a/src/scperteval/context.py +++ b/src/scperteval/context.py @@ -1,10 +1,13 @@ -"""The per-run engine: lazily builds and caches the shared building blocks, and -turns a (perturbation, source) into the exact view a protocol consumes. +"""The per-run engine that turns a (perturbation, source) into a protocol's view. + +Lazily builds and caches the shared building blocks, and turns a (perturbation, source) +into the exact view a protocol consumes. """ from __future__ import annotations import threading +from typing import TYPE_CHECKING, Any import numpy as np @@ -15,6 +18,9 @@ from .sources import SOURCES from .types import Protocol, RunConfig +if TYPE_CHECKING: + from .predictions import PredictionSet + class Context: """Owns the dataset, caches DE / PCA / control mean, and dispatches views. @@ -26,7 +32,7 @@ class Context: def __init__(self, dataset: Dataset, cfg: RunConfig): self.ds = dataset self.cfg = cfg - self.predictions = None # a PredictionSet in prediction-scoring mode, else None + self.predictions: PredictionSet | None = None # set in prediction-scoring mode self._local = threading.local() # Reentrant: several lazy initialisers (e.g. _ensure_ref_sums, ref_projection) # call reference() while already holding this lock, which a plain Lock would @@ -35,19 +41,21 @@ def __init__(self, dataset: Dataset, cfg: RunConfig): self._de: dict = {} self._mom: dict = {} self._weights: dict = {} - self._pca = None + self._pca: Any = None self._pca_k = 0 - self._control_mean = None - self._reference = None + self._control_mean: np.ndarray | None = None + self._reference: Reference | None = None self._ref_proj: dict = {} - self._ref_sums = None + self._ref_sums: tuple | None = None @property def perturbations(self): + """The list of perturbations evaluated in this run.""" return self.ds.perturbations @property def current_pert(self): + """The perturbation the current worker thread is processing.""" return getattr(self._local, "pert", None) @current_pert.setter @@ -55,8 +63,9 @@ def current_pert(self, value): self._local.pert = value def warm(self, protocols): - """Precompute shared singletons before the parallel loop so per-perturbation - threads only ever write per-perturbation cache keys. + """Precompute shared singletons before the parallel loop. + + So per-perturbation threads only ever write per-perturbation cache keys. """ self.control_mean() if any(p.representation in ("population", "de") for p in protocols): @@ -72,6 +81,7 @@ def warm(self, protocols): self.ref_projection(space) def view(self, pert: str, source: str, p: Protocol): + """Return ``source``'s datapoint for ``pert`` in the shape ``p`` consumes.""" if p.representation == "population": if source == "all_perturbed": return self._reference_population(p.space, pert) @@ -84,6 +94,7 @@ def view(self, pert: str, source: str, p: Protocol): raise ValueError(f"unknown protocol representation {p.representation!r}") def centroid(self, pert, source, centering): + """Pseudobulk centroid of ``source`` for ``pert``, optionally centered.""" arr = SOURCES[source](self, pert) if SOURCES.meta(source).get("provides") == "centroid": v = np.asarray(arr, dtype=np.float64).ravel() @@ -96,9 +107,10 @@ def centroid(self, pert, source, centering): return v def _de_view(self, pert, source, p): - """GT -> truth labels (its DEResult); a candidate -> its |score| ranking. - The negative candidate is tested against ``neg_reference`` (e.g. control) - rather than ``reference`` (the all-perturbed sample), the hybrid DE setup. + """Return the DE view: the truth's DEResult, or a candidate's ``|score|`` ranking. + + The negative candidate is tested against ``neg_reference`` (e.g. control) rather than + ``reference`` (the all-perturbed sample), the hybrid DE setup. """ if source == self.cfg.truth: return self.de(pert, self.cfg.truth, p.reference) @@ -106,8 +118,10 @@ def _de_view(self, pert, source, p): return np.abs(self.de(pert, source, reference).score) def de(self, pert, source, reference="all_perturbed"): - """DE for one (source vs reference) comparison; the reference moments are - leave-one-out, so a perturbation is never compared against a sample of itself. + """Differential expression for one (source vs reference) comparison, cached. + + The reference moments are leave-one-out, so a perturbation is never compared against + a sample of itself. """ method = self.cfg.de_method key = (self._mom_key(source, pert), self._mom_key(reference, pert), method) @@ -150,8 +164,9 @@ def wmse_weights(self, pert): # -- the all-perturbed reference: one sample, served leave-one-out ------------- def reference(self) -> Reference: - """The all-perturbed sample (subsampled + densified once), with each cell's - perturbation recorded so it can be served leave-one-out. + """The all-perturbed sample, subsampled and densified once. + + Each cell's perturbation is recorded so the sample can be served leave-one-out. """ if self._reference is None: with self._init_lock: @@ -162,8 +177,9 @@ def reference(self) -> Reference: return self._reference def _reference_population(self, space, pert): - """The reference in a feature space with the target perturbation removed: - project the whole sample (cached for global spaces) then drop its rows. + """The reference in a feature space with the target perturbation removed. + + Project the whole sample (cached for global spaces) then drop its rows. """ ref = self.reference() if SPACES.meta(space).get("global_space"): @@ -182,8 +198,10 @@ def ref_projection(self, space): return self._ref_proj[space] def _ensure_ref_sums(self): - """Cache the reference's column sums and sums-of-squares once, so leave-one-out - moments are an O(target cells) subtraction rather than a re-densify per perturbation. + """Cache the reference's column sums and sums-of-squares once. + + Leave-one-out moments are then an O(target cells) subtraction rather than a + re-densify per perturbation. """ if self._ref_sums is None: with self._init_lock: @@ -207,6 +225,7 @@ def _reference_moments(self, pert): return mean, var, k def control_mean(self): + """The control centroid (cached).""" if self._control_mean is None: with self._init_lock: if self._control_mean is None: @@ -225,8 +244,10 @@ def pca(self, k=50): PCA_FIT_CAP = 50000 def _fit_pca(self, n_components): - """Fit PCA on (nearly) all cells; the subsample cap is for the O(n^2) - distance populations, not the PCA basis, which needs many cells to be stable. + """Fit PCA on (nearly) all cells. + + The subsample cap is for the O(n^2) distance populations, not the PCA basis, which + needs many cells to be stable. """ from sklearn.decomposition import PCA diff --git a/src/scperteval/dataset.py b/src/scperteval/dataset.py index 8b7e742..da9da72 100644 --- a/src/scperteval/dataset.py +++ b/src/scperteval/dataset.py @@ -29,6 +29,7 @@ def __init__(self, adata, cfg: RunConfig): @classmethod def load(cls, path: str, cfg: RunConfig) -> Dataset: + """Load a dataset from a preprocessed ``.h5ad`` path.""" return cls(ad.read_h5ad(path), cfg) def _index(self): @@ -51,6 +52,7 @@ def _index(self): self._mean_sum = self._mean_matrix.sum(0) def cells(self, pert: str, half: str | None = None) -> np.ndarray: + """Cells for ``pert``: the first/second split half, or all cells when ``half`` is None.""" if half == "first": idx = self.halves[pert][0] elif half == "second": @@ -60,25 +62,31 @@ def cells(self, pert: str, half: str | None = None) -> np.ndarray: return self.adata.X[idx] def control_cells(self, cap: int) -> np.ndarray: + """A capped subsample of the non-targeting control cells.""" return self.adata.X[self._cap(self.control_idx, cap, "control")] def all_perturbed_indices(self, cap: int) -> np.ndarray: - """One all-perturbed subsample (the reference sample, shared across perturbations). + """Indices of one all-perturbed subsample (the shared reference sample). + The "pool" tag is a fixed reproducibility salt for the draw, not a public name. """ return self._cap(np.where(self.pert != self.cfg.control_label)[0], cap, "pool") def allpert_mean_except(self, pert: str) -> np.ndarray: + """Mean of all per-perturbation means, excluding ``pert`` (leave-one-out).""" k = len(self.perturbations) return (self._mean_sum - self._mean_matrix[self._row[pert]]) / max(k - 1, 1) def allpert_mean(self) -> np.ndarray: - """Mean of all per-perturbation means (no target exclusion); a single vector - shared across perturbations, used as the cross-perturbation ranking baseline. + """Mean of all per-perturbation means (no target exclusion). + + A single vector shared across perturbations, used as the cross-perturbation ranking + baseline. """ return self._mean_sum / max(len(self.perturbations), 1) def control_mean(self) -> np.ndarray: + """Pseudobulk centroid of the control cells.""" return np.asarray(self.adata.X[self.control_idx].mean(0)).ravel() def _cap(self, idx: np.ndarray, cap: int, *tags) -> np.ndarray: @@ -89,4 +97,5 @@ def _cap(self, idx: np.ndarray, cap: int, *tags) -> np.ndarray: def to_dense(X) -> np.ndarray: + """Return ``X`` as a dense array (densifying if sparse).""" return X.toarray() if sp.issparse(X) else np.asarray(X) diff --git a/src/scperteval/predictions.py b/src/scperteval/predictions.py index 35a7da2..a963214 100644 --- a/src/scperteval/predictions.py +++ b/src/scperteval/predictions.py @@ -50,6 +50,7 @@ def __init__(self, adata, ds: Dataset, cfg: RunConfig): @classmethod def load(cls, path: str, ds: Dataset, cfg: RunConfig) -> PredictionSet: + """Load a prediction ``.h5ad`` and gene-align it to ``ds``.""" return cls(ad.read_h5ad(path), ds, cfg) def cells(self, pert: str) -> np.ndarray: diff --git a/src/scperteval/protocols/metrics.py b/src/scperteval/protocols/metrics.py index 89dc766..5f62e49 100644 --- a/src/scperteval/protocols/metrics.py +++ b/src/scperteval/protocols/metrics.py @@ -18,7 +18,6 @@ import numpy as np from sklearn.metrics import average_precision_score, roc_auc_score - # --- shared parameter blocks, substituted into docstrings at decoration time --- _CENTROID = """\ @@ -69,18 +68,20 @@ def _doc(**subs): This decorator detects the column position of each placeholder and re-indents all continuation lines to match, so the substituted text stays inside the RST section. """ + def deco(fn): doc = fn.__doc__ for key, value in subs.items(): placeholder = f"%({key})s" while placeholder in doc: idx = doc.index(placeholder) - line_start = doc.rfind('\n', 0, idx) + 1 - indent = ' ' * (idx - line_start) - indented = ('\n' + indent).join(value.split('\n')) - doc = doc[:idx] + indented + doc[idx + len(placeholder):] + line_start = doc.rfind("\n", 0, idx) + 1 + indent = " " * (idx - line_start) + indented = ("\n" + indent).join(value.split("\n")) + doc = doc[:idx] + indented + doc[idx + len(placeholder) :] fn.__doc__ = doc return fn + return deco @@ -105,7 +106,7 @@ def _within_unbiased(sq, n): @_doc(params=_CENTROID) def pearson(gt, prediction, ctx): - """Pearson correlation between pseudobulk profiles. + r"""Pearson correlation between pseudobulk profiles. .. math:: @@ -126,7 +127,7 @@ def pearson(gt, prediction, ctx): @_doc(params=_CENTROID) def mse(gt, prediction, ctx): - """Mean squared error between pseudobulk profiles. + r"""Mean squared error between pseudobulk profiles. .. math:: @@ -146,7 +147,7 @@ def mse(gt, prediction, ctx): @_doc(params=_CENTROID_W) def weighted_mse(gt, prediction, ctx, exp=2.0): - """MSE weighted by ground-truth effect size raised to ``exp``. + r"""MSE weighted by ground-truth effect size raised to ``exp``. Weights are min-max normalised per-gene; high-effect genes contribute more. @@ -176,7 +177,7 @@ def weighted_mse(gt, prediction, ctx, exp=2.0): @_doc(params=_POPULATION) def energy_distance(gt, prediction, ctx): - """Székely–Rizzo energy distance with bias-corrected within-population terms. + r"""Székely–Rizzo energy distance with bias-corrected within-population terms. .. math:: @@ -207,7 +208,7 @@ def energy_distance(gt, prediction, ctx): @_doc(params=_POPULATION) def unbiased_mmd_median(gt, prediction, ctx): - """Unbiased RBF-MMD² with median-heuristic bandwidth (Gretton 2012). + r"""Unbiased RBF-MMD² with median-heuristic bandwidth (Gretton 2012). .. math:: @@ -254,7 +255,7 @@ def unbiased_mmd_median(gt, prediction, ctx): @_doc(params=_POPULATION) def sinkhorn_w2(gt, prediction, ctx, blur=0.05): - """Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2). + r"""Debiased Sinkhorn 2-Wasserstein distance (geomloss, p=2). .. math:: @@ -296,7 +297,7 @@ def sinkhorn_w2(gt, prediction, ctx, blur=0.05): @_doc(params=_DATASET) def rank_retrieval(gt, prediction, ctx, transpose=False): - """Cross-perturbation retrieval rank — dataset-scope metric, lower is better. + r"""Cross-perturbation retrieval rank — dataset-scope metric, lower is better. Builds the ``(n x n)`` squared-distance matrix between all predicted and ground-truth centroids, then reads off the diagonal rank (column-wise by default). @@ -378,7 +379,7 @@ def de_auroc(gt, prediction, ctx): @_doc(params=_DE) def de_overlap(gt, prediction, ctx, k=50): - """Top-k overlap between ground-truth and predicted DE gene rankings. + r"""Top-k overlap between ground-truth and predicted DE gene rankings. .. math:: diff --git a/src/scperteval/reference.py b/src/scperteval/reference.py index b140301..82b0fa3 100644 --- a/src/scperteval/reference.py +++ b/src/scperteval/reference.py @@ -8,8 +8,7 @@ class Reference: - """A comparison sample of cells (the all-perturbed subsample, or non-targeting - control), served leave-one-out. + """A fixed cell sample (all-perturbed subsample or control), served leave-one-out. ``subset(P)`` returns the sample with perturbation ``P``'s own cells removed, so a perturbation is never scored against a reference that contains itself. When @@ -34,6 +33,7 @@ def keep(self, exclude) -> np.ndarray | None: return mask def subset(self, exclude): + """Return the sample with ``exclude``'s own cells removed (leave-one-out).""" mask = self.keep(exclude) return self.cells if mask is None else self.cells[mask] diff --git a/src/scperteval/registry.py b/src/scperteval/registry.py index d6ba0da..9a81a11 100644 --- a/src/scperteval/registry.py +++ b/src/scperteval/registry.py @@ -34,6 +34,7 @@ def __init__(self, kind: str): def register(self, name: str, **meta) -> Callable: """Decorator that registers a function under ``name`` with optional metadata.""" + def deco(fn: Callable) -> Callable: self._items[name] = (fn, meta) return fn diff --git a/src/scperteval/runner.py b/src/scperteval/runner.py index f24ff31..a16587f 100644 --- a/src/scperteval/runner.py +++ b/src/scperteval/runner.py @@ -13,13 +13,15 @@ def n_workers(cfg) -> int: + """Resolve the worker-thread count (``0`` = auto: CPU count minus 2, capped at 16).""" return cfg.workers if cfg.workers > 0 else max(1, min(16, (os.cpu_count() or 2) - 2)) def resolve_roles(p: Protocol, cfg) -> dict: - """Map each candidate role a calibrator may require to a source name. ``positive`` / - ``negative`` come from the protocol (or a CLI override); ``prediction`` is always the - model-prediction source (used by the ``score`` calibrator). + """Map each candidate calibrator role to a source name. + + ``positive`` / ``negative`` come from the protocol (or a CLI override); ``prediction`` + is always the model-prediction source (used by the ``score`` calibrator). """ return { "positive": cfg.positive if cfg.positive != "auto" else p.positive, @@ -74,11 +76,11 @@ def work(pert): def _run_dataset(p: Protocol, ctx, calibrator: Calibrator, needed: dict): - """Dataset-scope protocols: build every perturbation's gt and control datapoints, hand - the metric the full lists at once, then read off each perturbation's score. + """Score dataset-scope protocols by handing the metric all perturbations at once. - Perturbations are treated as a single group (these datasets are single-covariate); - drf instead ranks within each covariate group. + Build every perturbation's gt and control datapoints, call the metric once on the full + lists, then read off each perturbation's score. Perturbations are treated as a single + group (these datasets are single-covariate); drf instead ranks within each covariate group. """ perts = ctx.perturbations @@ -100,8 +102,10 @@ def collect(source): def compute_de_export(ctx, methods): - """{method: (statistic, pvalue_adj)} matrices (perturbations x genes) for each - method's GT(first-half)-vs-all-perturbed differential expression. + """Per-gene DE matrices for each method, for export. + + Returns ``{method: (statistic, pvalue_adj)}`` matrices (perturbations x genes) for each + method's ground-truth-vs-all-perturbed differential expression. """ out = {} for method in methods: diff --git a/src/scperteval/sources.py b/src/scperteval/sources.py index 6bb0f6d..c482173 100644 --- a/src/scperteval/sources.py +++ b/src/scperteval/sources.py @@ -22,6 +22,7 @@ description="ground truth — the first half of a perturbation's cells (calibration truth)", ) def src_gt_half(ctx, pert): + """Ground-truth cells: the first half of the perturbation's cells.""" return ctx.ds.cells(pert, half="first") @@ -31,6 +32,7 @@ def src_gt_half(ctx, pert): description="ground truth — all of a perturbation's real cells (prediction-scoring truth)", ) def src_gt_all_cells(ctx, pert): + """Ground-truth cells: all of the perturbation's real cells.""" return ctx.ds.cells(pert) @@ -40,6 +42,7 @@ def src_gt_all_cells(ctx, pert): description="model-predicted cells for the perturbation, from the --predictions h5ad", ) def src_prediction(ctx, pert): + """Model-predicted cells for the perturbation, gene-aligned to the dataset.""" return ctx.predictions.cells(pert) @@ -49,11 +52,13 @@ def src_prediction(ctx, pert): description="technical duplicate — the held-out second half (single-cell positive control)", ) def src_tech_dup(ctx, pert): + """Technical-duplicate cells: the perturbation's held-out second half.""" return ctx.ds.cells(pert, half="second") @SOURCES.register("control", provides="cells", description="non-targeting control cells") def src_control(ctx, pert): + """Non-targeting control cells (subsampled).""" return ctx.ds.control_cells(ctx.cfg.subsample) @@ -63,6 +68,7 @@ def src_control(ctx, pert): description="all-perturbed reference sample, leave-one-out (single-cell negative control)", ) def src_all_perturbed(ctx, pert): + """All-perturbed reference cells, with the target perturbation removed.""" return ctx.reference().subset(pert) @@ -73,6 +79,7 @@ def src_all_perturbed(ctx, pert): "(pseudobulk sibling of all_perturbed; pseudobulk negative control)", ) def src_all_perturbed_mean(ctx, pert): + """All-perturbed pseudobulk mean, excluding the target perturbation.""" return ctx.ds.allpert_mean_except(pert) @@ -82,6 +89,7 @@ def src_all_perturbed_mean(ctx, pert): description="mean of all perturbations — shared baseline for the ranking protocols", ) def src_global_mean(ctx, pert): + """Pseudobulk mean over all perturbations (no target exclusion).""" return ctx.ds.allpert_mean() @@ -92,7 +100,9 @@ def src_global_mean(ctx, pert): "the dataset mean (pseudobulk positive control)", ) def src_interpolated(ctx, pert): - """Alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward + """DE-weighted blend toward the held-out replicate, else the all-perturbed mean. + + Alpha = 1 - adjusted p per gene (from the run's DE method, vs control); blend toward the held-out replicate where the gene is significant, else toward the all-perturbed mean. """ tech = np.asarray(to_dense(ctx.ds.cells(pert, half="second"))).mean(0) diff --git a/src/scperteval/types.py b/src/scperteval/types.py index 176299e..0db9605 100644 --- a/src/scperteval/types.py +++ b/src/scperteval/types.py @@ -43,8 +43,9 @@ class DEResult: @dataclass(frozen=True) class Param: - """A protocol's tunable knob — how a CLI value (``k=30``, ``padj=0.05``) is cast, - defaulted, and applied. ``space`` maps the value to a feature-space name; when it is + """A protocol's tunable knob: how a CLI value is cast, defaulted, and applied. + + ``space`` maps the value (``k=30``, ``padj=0.05``) to a feature-space name; when it is ``None`` the value is passed straight to the metric as a keyword argument. """ @@ -106,10 +107,12 @@ class Protocol: @property def parameterised(self) -> bool: + """Whether this protocol takes a CLI-supplied parameter.""" return self.param is not None def resolve(self, value) -> Protocol: """Concrete protocol for a tunable one at ``value`` (sets the space or metric arg).""" + assert self.param is not None # resolve() is only called on parameterised protocols suffix = f"{value:g}" if isinstance(value, float) else str(value) name = f"{self.name}={suffix}" if self.param.space is not None: From e28b813f4cc94d084888e8f5c4128376cb929563 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 19:42:34 -0700 Subject: [PATCH 12/15] test: add calibrate, score, PredictionSet, and CLI coverage Cover DRF/BDS calibration, prediction-vs-ground-truth scoring (perfect=optimal, degraded=worse), PredictionSet gene alignment + mismatch/missing-perturbation errors, and end-to-end calibrate/score/de CLI dispatch. --- tests/conftest.py | 90 +++++++++++++++++++++++++++++++++++++++++ tests/test_calibrate.py | 53 ++++++++++++++++++++++++ tests/test_cli.py | 35 ++++++++++++++++ tests/test_de.py | 13 +++--- tests/test_score.py | 82 +++++++++++++++++++++++++++++++++++++ 5 files changed, 267 insertions(+), 6 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_calibrate.py create mode 100644 tests/test_cli.py create mode 100644 tests/test_score.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a3d3e06 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,90 @@ +"""Shared fixtures and tiny in-memory dataset builders for the test suite.""" + +from __future__ import annotations + +import anndata as ad +import numpy as np +import pytest + +from scperteval.types import RunConfig + +# Each perturbation gets a strong, *distinct* block of DE genes, so ground-truth DEGs +# form a proper subset (de_auprc/auroc are well-defined) and the perturbation signal is +# unambiguous for the calibration controls. +_DE_GENES = { + "pertA": range(0, 6), + "pertB": range(15, 21), + "pertC": range(30, 36), + "pertD": range(45, 51), +} + + +def make_dataset(seed: int = 0, ng: int = 60, n_ctrl: int = 150, n_pert: int = 120) -> ad.AnnData: + """A tiny log-normalised-looking dataset: control + 4 perturbations with distinct DE blocks.""" + rng = np.random.default_rng(seed) + parts = [rng.poisson(1.0, (n_ctrl, ng)).astype(np.float32)] + labels = ["control"] * n_ctrl + for lab, genes in _DE_GENES.items(): + x = rng.poisson(1.0, (n_pert, ng)).astype(np.float32) + x[:, list(genes)] += 6.0 + parts.append(x) + labels += [lab] * n_pert + adata = ad.AnnData(np.vstack(parts)) + adata.var_names = [f"g{i}" for i in range(ng)] + adata.obs["perturbation"] = labels + return adata + + +def make_predictions( + dataset: ad.AnnData, kind: str = "perfect", shuffle_genes: bool = False, seed: int = 1 +) -> ad.AnnData: + """Build a prediction AnnData from the dataset's perturbed cells. + + ``perfect`` is an exact replica (should score optimally); ``degraded`` shrinks each cell + toward the control mean plus noise (a worse prediction). ``shuffle_genes`` permutes the + gene columns to exercise the name-based alignment. + """ + rng = np.random.default_rng(seed) + pert = np.asarray(dataset.obs["perturbation"]).astype(str) + mask = pert != "control" + sub = dataset[mask].copy() + x = np.asarray(sub.X, dtype=np.float32) + if kind == "degraded": + ctrl_mean = np.asarray(dataset.X[pert == "control"]).mean(0) + x = np.clip(0.4 * x + 0.6 * ctrl_mean + rng.normal(0, 0.2, x.shape), 0, None).astype(np.float32) + elif kind != "perfect": + raise ValueError(f"unknown prediction kind {kind!r}") + pred = ad.AnnData(x, obs=sub.obs.copy()) + pred.var_names = list(dataset.var_names) + if shuffle_genes: + pred = pred[:, rng.permutation(pred.n_vars)].copy() + return pred + + +def make_cfg(**kw) -> RunConfig: + """A RunConfig with small, fast, deterministic defaults for tests.""" + base = dict(dataset="-", protocols=[], de_method="t-test", subsample=400, seed=0, min_cells=10, workers=1) + base.update(kw) + return RunConfig(**base) + + +@pytest.fixture +def dataset_adata() -> ad.AnnData: + return make_dataset() + + +@pytest.fixture +def dataset_path(tmp_path, dataset_adata) -> str: + path = tmp_path / "dataset.h5ad" + dataset_adata.write_h5ad(path) + return str(path) + + +@pytest.fixture +def cfg_factory(): + return make_cfg + + +@pytest.fixture +def predictions_factory(): + return make_predictions diff --git a/tests/test_calibrate.py b/tests/test_calibrate.py new file mode 100644 index 0000000..c81111a --- /dev/null +++ b/tests/test_calibrate.py @@ -0,0 +1,53 @@ +"""Calibration mode: DRF/BDS over built-in positive/negative controls.""" + +from __future__ import annotations + +import numpy as np + +from scperteval.calibrators import CALIBRATORS +from scperteval.cli import _concrete +from scperteval.context import Context +from scperteval.dataset import Dataset +from scperteval.protocols.table import PROTOCOLS +from scperteval.runner import run_protocol + + +def _run(name, calibrator, dataset_adata, cfg): + ctx = Context(Dataset(dataset_adata, cfg), cfg) + return run_protocol(_concrete(PROTOCOLS[name]), ctx, CALIBRATORS[calibrator]) + + +def test_calibrators_registered(): + assert {"drf", "bds", "score"} <= set(CALIBRATORS) + # drf/bds need both controls; score needs only the prediction + assert CALIBRATORS["drf"].requires == ("positive", "negative") + assert CALIBRATORS["score"].requires == ("prediction",) + + +def test_drf_rows_have_control_columns(dataset_adata, cfg_factory): + agg, rows, seconds = _run("pearson_ctrl", "drf", dataset_adata, cfg_factory()) + assert seconds >= 0.0 + assert len(rows) == 4 # one row per perturbation + cols = set(rows[0]) + assert {"protocol", "perturbation", "raw_positive", "raw_negative", "drf"} <= cols + assert {"mean", "median"} <= set(agg) + + +def test_drf_positive_for_real_signal(dataset_adata, cfg_factory): + # the positive control (held-out replicate) should beat the uninformative baseline, + # so mean DRF is clearly positive on a dataset with strong perturbation signal. + for name in ("pearson_ctrl", "mse"): + agg, _, _ = _run(name, "drf", dataset_adata, cfg_factory()) + assert agg["mean"] > 0.0, name + + +def test_bds_is_a_fraction(dataset_adata, cfg_factory): + agg, rows, _ = _run("pearson_ctrl", "bds", dataset_adata, cfg_factory()) + assert 0.0 <= agg["bds"] <= 1.0 + assert all(r["bds"] in (0.0, 1.0) for r in rows) # per-perturbation BDS is binary + + +def test_de_protocol_calibrates(dataset_adata, cfg_factory): + # the de representation should produce a finite, well-defined auprc (distinct DE blocks) + agg, _, _ = _run("de_auprc", "drf", dataset_adata, cfg_factory()) + assert np.isfinite(agg["mean"]) diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..f0382f6 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,35 @@ +"""End-to-end CLI dispatch for the calibrate / score / de subcommands.""" + +from __future__ import annotations + +import pytest + +from scperteval.cli import main + + +def test_calibrate_writes_drf_csv(dataset_path, tmp_path): + main(["calibrate", dataset_path, "-p", "pearson_ctrl,mse", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__drf.csv"))) == 1 + + +def test_calibrate_bds_output(dataset_path, tmp_path): + main(["calibrate", dataset_path, "-p", "mse", "--output", "bds", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__bds.csv"))) == 1 + + +def test_score_writes_score_csv(dataset_path, dataset_adata, predictions_factory, tmp_path): + pred_path = tmp_path / "pred.h5ad" + predictions_factory(dataset_adata, kind="degraded").write_h5ad(pred_path) + main(["score", dataset_path, str(pred_path), "-p", "pearson,mse,de_auprc", "--out-dir", str(tmp_path), "--quiet"]) + assert len(list(tmp_path.glob("*__score.csv"))) == 1 + + +def test_de_writes_h5(dataset_path, tmp_path): + main(["de", dataset_path, "--methods", "t-test", "--out-dir", str(tmp_path)]) + assert len(list(tmp_path.glob("*__de.h5"))) == 1 + + +def test_calibrate_rejects_score_output(dataset_path, tmp_path): + # `score` is a scoring-mode calibrator, not selectable from `calibrate --output` + with pytest.raises(SystemExit): + main(["calibrate", dataset_path, "-p", "mse", "--output", "score", "--out-dir", str(tmp_path)]) diff --git a/tests/test_de.py b/tests/test_de.py index 82bf471..1c1bc56 100644 --- a/tests/test_de.py +++ b/tests/test_de.py @@ -1,5 +1,6 @@ """Tests for the differential-expression backends, focused on the scanpy ``t-test_overestim_var`` variant added as a selectable DE method.""" + from __future__ import annotations import anndata as ad @@ -18,8 +19,7 @@ def _scanpy_overestim(Xt, Xr): adata.var_names = names adata.obs["g"] = ["target"] * Xt.shape[0] + ["reference"] * Xr.shape[0] adata.obs["g"] = adata.obs["g"].astype("category") - sc.tl.rank_genes_groups(adata, "g", groups=["target"], reference="reference", - method="t-test_overestim_var") + sc.tl.rank_genes_groups(adata, "g", groups=["target"], reference="reference", method="t-test_overestim_var") res = adata.uns["rank_genes_groups"] order = np.array([int(n) for n in res["names"]["target"]]) scores = np.empty(ng) @@ -32,8 +32,8 @@ def _scanpy_overestim(Xt, Xr): def test_overestim_var_matches_scanpy(): """Our backend reproduces scanpy's t-test_overestim_var statistic and p-values.""" rng = np.random.default_rng(0) - Xt = rng.poisson(1.0, (40, 60)).astype(np.float64) # small target group - Xr = rng.poisson(1.3, (90, 60)).astype(np.float64) # larger reference + Xt = rng.poisson(1.0, (40, 60)).astype(np.float64) # small target group + Xr = rng.poisson(1.3, (90, 60)).astype(np.float64) # larger reference de = de_ttest_overestim(Xt, Xr) sc_scores, sc_pvals = _scanpy_overestim(Xt, Xr) assert np.allclose(de.score, sc_scores, atol=1e-5, rtol=1e-4) @@ -76,8 +76,9 @@ def test_overestim_var_runs_through_export_path(): adata = ad.AnnData(np.vstack(parts).astype(np.float64)) adata.var_names = [f"g{i}" for i in range(ng)] adata.obs["perturbation"] = labels - cfg = RunConfig(dataset="-", protocols=[], de_method="t-test_overestim_var", - subsample=200, seed=0, min_cells=10, workers=1) + cfg = RunConfig( + dataset="-", protocols=[], de_method="t-test_overestim_var", subsample=200, seed=0, min_cells=10, workers=1 + ) ctx = Context(Dataset(adata, cfg), cfg) out = compute_de_export(ctx, ["t-test_overestim_var"]) stat, padj = out["t-test_overestim_var"] diff --git a/tests/test_score.py b/tests/test_score.py new file mode 100644 index 0000000..28f8ba9 --- /dev/null +++ b/tests/test_score.py @@ -0,0 +1,82 @@ +"""Prediction-scoring mode: score predictions against ground truth, and PredictionSet.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from scperteval.calibrators import CALIBRATORS +from scperteval.cli import _concrete +from scperteval.context import Context +from scperteval.dataset import Dataset +from scperteval.predictions import PredictionSet +from scperteval.protocols.table import PROTOCOLS +from scperteval.runner import run_protocol + + +def _score(name, dataset_adata, pred_adata, cfg): + ds = Dataset(dataset_adata, cfg) + ctx = Context(ds, cfg) + ctx.predictions = PredictionSet(pred_adata, ds, cfg) + return run_protocol(_concrete(PROTOCOLS[name]), ctx, CALIBRATORS["score"]) + + +def _score_cfg(cfg_factory): + return cfg_factory(truth="gt_all_cells", output="score") + + +def test_score_rows_have_prediction_column(dataset_adata, predictions_factory, cfg_factory): + pred = predictions_factory(dataset_adata, kind="perfect") + agg, rows, _ = _score("mse", dataset_adata, pred, _score_cfg(cfg_factory)) + assert {"protocol", "perturbation", "raw_prediction", "score"} <= set(rows[0]) + assert {"mean", "median"} <= set(agg) + + +def test_perfect_prediction_is_optimal(dataset_adata, predictions_factory, cfg_factory): + # an exact replica of the real cells must score optimally on every representation, + # even with the prediction's gene columns shuffled (name-based alignment). + pred = predictions_factory(dataset_adata, kind="perfect", shuffle_genes=True) + cfg = _score_cfg(cfg_factory) + assert _score("pearson", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(1.0, abs=1e-6) + assert _score("mse", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(0.0, abs=1e-6) + assert _score("de_auprc", dataset_adata, pred, cfg)[0]["mean"] == pytest.approx(1.0, abs=1e-6) + + +def test_degraded_prediction_scores_worse(dataset_adata, predictions_factory, cfg_factory): + cfg = _score_cfg(cfg_factory) + perfect = predictions_factory(dataset_adata, kind="perfect") + degraded = predictions_factory(dataset_adata, kind="degraded") + + def mean(name, pred): + return _score(name, dataset_adata, pred, cfg)[0]["mean"] + + assert mean("mse", degraded) > mean("mse", perfect) # error up + assert mean("de_auprc", degraded) < mean("de_auprc", perfect) # auprc down + + +def test_gene_set_mismatch_raises(dataset_adata, predictions_factory, cfg_factory): + ds = Dataset(dataset_adata, cfg_factory()) + pred_missing = predictions_factory(dataset_adata, kind="perfect")[:, :-1].copy() + with pytest.raises(ValueError, match="gene mismatch"): + PredictionSet(pred_missing, ds, cfg_factory()) + + +def test_missing_perturbation_raises(dataset_adata, predictions_factory, cfg_factory): + ds = Dataset(dataset_adata, cfg_factory()) + pred = predictions_factory(dataset_adata, kind="perfect") + only_a = pred[np.asarray(pred.obs["perturbation"]) == "pertA"].copy() + ps = PredictionSet(only_a, ds, cfg_factory()) + with pytest.raises(ValueError, match="no cells for perturbation"): + ps.cells("pertB") + + +def test_gene_alignment_reorders_by_name(dataset_adata, predictions_factory, cfg_factory): + # a shuffled-gene prediction is reordered to the dataset's gene order + ds = Dataset(dataset_adata, cfg_factory()) + pred = predictions_factory(dataset_adata, kind="perfect", shuffle_genes=True) + ps = PredictionSet(pred, ds, cfg_factory()) + cells = np.asarray(ps.cells("pertA")) + assert cells.shape[1] == len(ds.var_names) + # pertA's DE block (genes 0-5) should be the high-expression columns after realignment + col_means = cells.mean(0) + assert col_means[list(range(0, 6))].min() > col_means[10:].max() From 5552bf858c88414b5b1d20b4335e57a1b78f9b8a Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 19:44:54 -0700 Subject: [PATCH 13/15] ci: add GitHub Actions (lint, type-check, test, docs) and pre-commit hooks CI runs ruff check + format-check, mypy, pyright, and pytest on py3.11/3.12, plus a sphinx docs-build job. pre-commit wires ruff (+autofix), ruff-format, and basic hygiene hooks. --- .github/workflows/ci.yml | 57 ++++++++++++++++++++++++++++++++ .pre-commit-config.yaml | 20 +++++++++++ docs/conf.py | 2 +- docs/extensions/typed_returns.py | 4 +-- 4 files changed, 80 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a082a45 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,57 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint-and-test: + name: lint + test (py${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11", "3.12"] + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + - name: Install package + dev tools + run: | + python -m pip install --upgrade pip + pip install -e . + pip install ruff mypy pyright pytest + - name: Ruff lint + run: ruff check . + - name: Ruff format check + run: ruff format --check . + - name: mypy + run: mypy src/scperteval + - name: pyright + run: pyright src/scperteval + - name: pytest + run: pytest -q + + docs: + name: docs build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + - name: Install package + docs deps + run: | + python -m pip install --upgrade pip + pip install -e . + pip install --group docs + - name: Build HTML docs + run: sphinx-build -b html -n docs docs/_build/html diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..2e7c9e3 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +# Run `pre-commit install` once; hooks then run on every commit. +# Update pinned revs with `pre-commit autoupdate`. +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.13 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-toml + - id: check-merge-conflict + - id: check-added-large-files + args: [--maxkb=1024] diff --git a/docs/conf.py b/docs/conf.py index 0d9561a..1c8f093 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -117,7 +117,7 @@ pygments_style = "default" katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None -nitpick_ignore = [ # type: ignore +nitpick_ignore = [ # type: ignore # Add exceptions here for links outside your control that fail to resolve # ("py:class", "igraph.Graph"), ] diff --git a/docs/extensions/typed_returns.py b/docs/extensions/typed_returns.py index 1c2e393..5342ccf 100644 --- a/docs/extensions/typed_returns.py +++ b/docs/extensions/typed_returns.py @@ -5,7 +5,7 @@ from collections.abc import Generator, Iterable from sphinx.application import Sphinx -from sphinx.ext.napoleon import NumpyDocstring # type: ignore +from sphinx.ext.napoleon import NumpyDocstring # type: ignore def _process_return(lines: Iterable[str]) -> Generator[str, None, None]: @@ -28,4 +28,4 @@ def _parse_returns_section(self: NumpyDocstring, section: str) -> list[str]: def setup(app: Sphinx): """Set app.""" - NumpyDocstring._parse_returns_section = _parse_returns_section # type: ignore + NumpyDocstring._parse_returns_section = _parse_returns_section # type: ignore From 66ceb66fe8fafecd3419a83dca4dc8347c6e5d28 Mon Sep 17 00:00:00 2001 From: HugoHakem Date: Mon, 29 Jun 2026 19:52:35 -0700 Subject: [PATCH 14/15] docs: restore protocol-table sphinx extension --- docs/extensions/protocol_table.py | 60 +++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 docs/extensions/protocol_table.py diff --git a/docs/extensions/protocol_table.py b/docs/extensions/protocol_table.py new file mode 100644 index 0000000..360ffc0 --- /dev/null +++ b/docs/extensions/protocol_table.py @@ -0,0 +1,60 @@ +"""Sphinx directive that auto-generates a reference table of evaluation protocols.""" +from __future__ import annotations + +from docutils import nodes +from docutils.parsers.rst import Directive +from sphinx.application import Sphinx + + +class ProtocolTableDirective(Directive): + """Emit a table of all protocols from ``scperteval.protocols.TABLE``.""" + + def run(self): + """Build and return the protocol reference table node.""" + from scperteval.protocols import TABLE + + table = nodes.table() + tgroup = nodes.tgroup(cols=4) + table += tgroup + + for _ in range(4): + tgroup += nodes.colspec(colwidth=1) + + thead = nodes.thead() + tgroup += thead + header_row = nodes.row() + for text in ("Name", "Group", "Representation", "Better"): + entry = nodes.entry() + entry += nodes.paragraph(text=text) + header_row += entry + thead += header_row + + tbody = nodes.tbody() + tgroup += tbody + for p in TABLE: + row = nodes.row() + + # Name cell — link to the metric function page if available + name_entry = nodes.entry() + metric_name = p.metric.__name__ if hasattr(p.metric, "__name__") else p.name + ref_id = f"scperteval.protocols.metrics.{metric_name}" + ref = nodes.reference("", p.name, internal=True, refuri=f"generated/{ref_id}.html") + name_para = nodes.paragraph() + name_para += ref + name_entry += name_para + row += name_entry + + for text in (p.group, p.representation, p.better): + entry = nodes.entry() + entry += nodes.paragraph(text=text) + row += entry + + tbody += row + + return [table] + + +def setup(app: Sphinx): + """Register the ``protocol-table`` directive with Sphinx.""" + app.add_directive("protocol-table", ProtocolTableDirective) + return {"version": "0.1", "parallel_read_safe": True} From ae740a812024e90d6654a75a026b14cd8fb1e01b Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Mon, 29 Jun 2026 19:52:35 -0700 Subject: [PATCH 15/15] docs: fix strict sphinx build (-W) De-double LaTeX backslashes left raw by ruff's r-string autofix, escape |effect size| RST substitutions, and add nitpick_ignore for internal type-hint classes. Docs now build clean under -W -n. --- docs/conf.py | 5 ++-- docs/extensions/protocol_table.py | 1 + src/scperteval/blocks/de.py | 2 +- src/scperteval/blocks/spaces.py | 2 +- src/scperteval/context.py | 2 +- src/scperteval/protocols/metrics.py | 38 ++++++++++++++--------------- 6 files changed, 26 insertions(+), 24 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 1c8f093..d356e9c 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -118,6 +118,7 @@ katex_prerender = shutil.which(katex.NODEJS_BINARY) is not None nitpick_ignore = [ # type: ignore - # Add exceptions here for links outside your control that fail to resolve - # ("py:class", "igraph.Graph"), + # Internal classes referenced in type hints but not given their own API page. + ("py:class", "Dataset"), + ("py:class", "scperteval.reference.Reference"), ] diff --git a/docs/extensions/protocol_table.py b/docs/extensions/protocol_table.py index 360ffc0..a0203e6 100644 --- a/docs/extensions/protocol_table.py +++ b/docs/extensions/protocol_table.py @@ -1,4 +1,5 @@ """Sphinx directive that auto-generates a reference table of evaluation protocols.""" + from __future__ import annotations from docutils import nodes diff --git a/src/scperteval/blocks/de.py b/src/scperteval/blocks/de.py index 3879315..028cb41 100644 --- a/src/scperteval/blocks/de.py +++ b/src/scperteval/blocks/de.py @@ -35,7 +35,7 @@ def de_my_test(target, reference): def moments(X): r"""Per-gene mean, sample variance, and cell count for a cell matrix. - Sparse- and dense-aware; uses :math:`\\text{Var}(X) = E[X^2] - E[X]^2` + Sparse- and dense-aware; uses :math:`\text{Var}(X) = E[X^2] - E[X]^2` with Bessel's correction. Parameters diff --git a/src/scperteval/blocks/spaces.py b/src/scperteval/blocks/spaces.py index 1f05059..9fd1e8e 100644 --- a/src/scperteval/blocks/spaces.py +++ b/src/scperteval/blocks/spaces.py @@ -45,7 +45,7 @@ def space(X, ctx, pert): def top_space(k: int) -> str: - """top-k genes by |ground-truth effect size| (registered on demand).""" + """top-k genes by absolute ground-truth effect size (registered on demand).""" name = f"top_{k}" if name not in SPACES: register_de_space( diff --git a/src/scperteval/context.py b/src/scperteval/context.py index 6869f67..80e0ced 100644 --- a/src/scperteval/context.py +++ b/src/scperteval/context.py @@ -152,7 +152,7 @@ def _mom_key(source, pert): return source if source == "control" else (source, pert) def wmse_weights(self, pert): - """Mejia DEG weights: min-max normalised |effect size| of GT vs the reference.""" + """Mejia DEG weights: min-max normalised absolute effect size of GT vs the reference.""" if pert not in self._weights: s = np.abs(self.de(pert, self.cfg.truth, "all_perturbed").score) finite = np.isfinite(s) diff --git a/src/scperteval/protocols/metrics.py b/src/scperteval/protocols/metrics.py index 5f62e49..2720ed1 100644 --- a/src/scperteval/protocols/metrics.py +++ b/src/scperteval/protocols/metrics.py @@ -110,8 +110,8 @@ def pearson(gt, prediction, ctx): .. math:: - r = \\frac{\\sum_g (gt_g - \\bar{gt})(pred_g - \\bar{pred})}{ - \\sqrt{\\sum_g (gt_g - \\bar{gt})^2 \\cdot \\sum_g (pred_g - \\bar{pred})^2}} + r = \frac{\sum_g (gt_g - \bar{gt})(pred_g - \bar{pred})}{ + \sqrt{\sum_g (gt_g - \bar{gt})^2 \cdot \sum_g (pred_g - \bar{pred})^2}} Parameters ---------- @@ -131,7 +131,7 @@ def mse(gt, prediction, ctx): .. math:: - \\text{MSE} = \\frac{1}{G}\\sum_{g=1}^G (gt_g - pred_g)^2 + \text{MSE} = \frac{1}{G}\sum_{g=1}^G (gt_g - pred_g)^2 Parameters ---------- @@ -153,8 +153,8 @@ def weighted_mse(gt, prediction, ctx, exp=2.0): .. math:: - \\text{wMSE} = \\sum_g w_g \\,(gt_g - pred_g)^2, \\quad - w_g \\propto |s_g|^{\\text{exp}} / \\sum_{g'} |s_{g'}|^{\\text{exp}} + \text{wMSE} = \sum_g w_g \,(gt_g - pred_g)^2, \quad + w_g \propto |s_g|^{\text{exp}} / \sum_{g'} |s_{g'}|^{\text{exp}} where :math:`s_g` is the ground-truth DE t-statistic for gene :math:`g`. @@ -181,8 +181,8 @@ def energy_distance(gt, prediction, ctx): .. math:: - E(X, Y) = 2\\,\\mathbb{E}[\\|X - Y\\|] - - \\mathbb{E}[\\|X - X'\\|] - \\mathbb{E}[\\|Y - Y'\\|] + E(X, Y) = 2\,\mathbb{E}[\|X - Y\|] + - \mathbb{E}[\|X - X'\|] - \mathbb{E}[\|Y - Y'\|] Within-population terms use the unbiased (U-statistic) estimator. @@ -212,12 +212,12 @@ def unbiased_mmd_median(gt, prediction, ctx): .. math:: - \\widehat{\\text{MMD}}^2(X, Y) - = \\frac{1}{n(n-1)} \\sum_{i \\neq j} k(x_i, x_j) - + \\frac{1}{m(m-1)} \\sum_{i \\neq j} k(y_i, y_j) - - \\frac{2}{nm} \\sum_{i,j} k(x_i, y_j) + \widehat{\text{MMD}}^2(X, Y) + = \frac{1}{n(n-1)} \sum_{i \neq j} k(x_i, x_j) + + \frac{1}{m(m-1)} \sum_{i \neq j} k(y_i, y_j) + - \frac{2}{nm} \sum_{i,j} k(x_i, y_j) - with :math:`k(x,y) = \\exp(-\\|x-y\\|^2 / 2\\sigma^2)` and :math:`\\sigma` the median + with :math:`k(x,y) = \exp(-\|x-y\|^2 / 2\sigma^2)` and :math:`\sigma` the median pairwise Euclidean distance over the pooled sample. Parameters @@ -259,10 +259,10 @@ def sinkhorn_w2(gt, prediction, ctx, blur=0.05): .. math:: - W_2(X, Y) = \\sqrt{2\\,S_\\varepsilon(X, Y)} + W_2(X, Y) = \sqrt{2\,S_\varepsilon(X, Y)} - where :math:`S_\\varepsilon` is the debiased Sinkhorn divergence with blur - :math:`\\varepsilon`. Requires ``geomloss`` and ``torch``. + where :math:`S_\varepsilon` is the debiased Sinkhorn divergence with blur + :math:`\varepsilon`. Requires ``geomloss`` and ``torch``. Parameters ---------- @@ -304,8 +304,8 @@ def rank_retrieval(gt, prediction, ctx, transpose=False): .. math:: - \\text{rank}(a) = \\frac{\\text{rank}_{\\text{col}}(D_{aa})}{n - 1}, \\quad - D_{ij} = \\|P_i - G_j\\|^2 + \text{rank}(a) = \frac{\text{rank}_{\text{col}}(D_{aa})}{n - 1}, \quad + D_{ij} = \|P_i - G_j\|^2 where :math:`P_i` and :math:`G_j` are the predicted and ground-truth centroids. ``transpose_rank`` transposes the matrix first (each prediction ranked among all GTs). @@ -383,8 +383,8 @@ def de_overlap(gt, prediction, ctx, k=50): .. math:: - \\text{Overlap}_k - = \\frac{|\\text{top-}k(|gt.score|) \\cap \\text{top-}k(pred)|}{k} + \text{Overlap}_k + = \frac{|\text{top-}k(|gt.score|) \cap \text{top-}k(pred)|}{k} Parameters ----------