diff --git a/.gitignore b/.gitignore index a168fb481..04202f523 100644 --- a/.gitignore +++ b/.gitignore @@ -35,3 +35,4 @@ __pycache__/ benchmark/benchmarks/data benchmarks/benchmarks/data benchmarks/pkgs +benchmarks/results diff --git a/benchmarks/README.md b/benchmarks/README.md index 86bfd66b7..cbf22a82a 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -74,3 +74,44 @@ All benchmarks: ### View in the browser: You can view the benchmarks in the browser with `asv publish` followed by `asv preview`. If you want to include benchmarks of a local branch, I think you'll have to add that branch to the `"branches"` list in `asv.conf.json`. + +## Dask Chunk Default Exploration + +Issue [#2036](https://github.com/scverse/anndata/issues/2036) needs benchmark data for choosing more sensible virtual Dask chunk defaults when reading chunked HDF5/Zarr arrays lazily. The script below creates dense `X` arrays with controlled on-disk chunks, reads them through `anndata.experimental.read_elem_lazy`, runs a small set of Dask and Scanpy-style workloads, and writes one CSV row per grid point. Rows include runtime package versions, store size, task count, elapsed time, and coarse process/worker memory readings. 
+ +Run a small local smoke benchmark: + +```bash +uv run --group test-min python benchmarks/scripts/dask_chunk_grid.py \ + --shape 1000,250 \ + --store-types h5ad zarr \ + --on-disk-chunks 100,250 \ + --dask-chunks default \ + --dask-chunks 500,-1 \ + --workers 1 \ + --threads-per-worker 1 \ + --workloads sum_axis0 normalize_log1p_slice scanpy_normalize_log1p \ + --repeats 1 \ + --force +``` + +Run a larger grid for analysis: + +```bash +uv run --group test-min python benchmarks/scripts/dask_chunk_grid.py \ + --shape 12000,3000 \ + --store-types h5ad zarr \ + --on-disk-chunks 256,1024 \ + --on-disk-chunks 1024,1024 \ + --dask-chunks default \ + --dask-chunks 1024,-1 \ + --dask-chunks 4096,-1 \ + --workers 1 \ + --workers 4 \ + --threads-per-worker 1 \ + --workloads sum_axis0 sum_axis1 normalize_log1p_slice scanpy_normalize_log1p \ + --repeats 3 \ + --force +``` + +By default, results are written to `benchmarks/results/dask_chunk_grid.csv`. Use `benchmarks/notebooks/dask_chunk_grid_analysis.ipynb` to compare elapsed time, task counts, and coarse memory readings across the grid. diff --git a/benchmarks/notebooks/dask_chunk_grid_analysis.ipynb b/benchmarks/notebooks/dask_chunk_grid_analysis.ipynb new file mode 100644 index 000000000..10235602a --- /dev/null +++ b/benchmarks/notebooks/dask_chunk_grid_analysis.ipynb @@ -0,0 +1,194 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7fb27b941602401d91542211134fc71a", + "metadata": {}, + "source": [ + "# Dask Chunk Grid Analysis\n", + "\n", + "This notebook summarizes CSV output from `benchmarks/scripts/dask_chunk_grid.py` for issue #2036. It is intentionally small: run the benchmark script first, then use the tables and plots below to compare on-disk chunks, virtual Dask chunks, worker settings, task counts, storage overhead, and coarse memory readings." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acae54e37e7d407bbb7b55eff062a284", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from pathlib import Path\n", + "\n", + "import pandas as pd\n", + "\n", + "results_path = Path(\"../results/dask_chunk_grid.csv\")\n", + "df = pd.read_csv(results_path)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a63283cbaf04dbcab1f6479b197f3a8", + "metadata": {}, + "outputs": [], + "source": [ + "metadata_cols = [\n", + " \"run_started_at\",\n", + " \"python_version\",\n", + " \"platform\",\n", + " \"anndata_version\",\n", + " \"numpy_version\",\n", + " \"h5py_version\",\n", + " \"zarr_version\",\n", + " \"dask_version\",\n", + " \"distributed_version\",\n", + " \"scanpy_version\",\n", + "]\n", + "\n", + "df[metadata_cols].drop_duplicates()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8dd0d8092fe74a7c96281538738b07e2", + "metadata": {}, + "outputs": [], + "source": [ + "group_cols = [\n", + " \"store_type\",\n", + " \"zarr_format\",\n", + " \"zarr_shards\",\n", + " \"shape\",\n", + " \"on_disk_chunks\",\n", + " \"dask_chunks_arg\",\n", + " \"workers\",\n", + " \"threads_per_worker\",\n", + " \"processes\",\n", + " \"workload\",\n", + "]\n", + "\n", + "summary = (\n", + " df\n", + " .groupby(group_cols, dropna=False)\n", + " .agg(\n", + " elapsed_median_s=(\"elapsed_s\", \"median\"),\n", + " elapsed_min_s=(\"elapsed_s\", \"min\"),\n", + " task_count=(\"task_count\", \"median\"),\n", + " dataset_nbytes=(\"dataset_nbytes\", \"median\"),\n", + " store_nbytes=(\"store_nbytes\", \"median\"),\n", + " worker_rss_after_mb=(\"worker_rss_after_mb\", \"median\"),\n", + " runs=(\"elapsed_s\", \"size\"),\n", + " )\n", + " .reset_index()\n", + " .sort_values([\"workload\", \"elapsed_median_s\"])\n", + ")\n", + "summary[\"store_overhead_ratio\"] = summary[\"store_nbytes\"] / 
summary[\"dataset_nbytes\"]\n", + "summary.head(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72eea5119410473aa328ad9291626812", + "metadata": {}, + "outputs": [], + "source": [ + "best_by_workload = summary.loc[\n", + " summary.groupby([\"store_type\", \"workload\"])[\"elapsed_median_s\"].idxmin()\n", + "]\n", + "best_by_workload.sort_values([\"store_type\", \"workload\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8edb47106e1a46a883d545849b8ab81b", + "metadata": {}, + "outputs": [], + "source": [ + "baseline_cols = [\n", + " \"store_type\",\n", + " \"zarr_format\",\n", + " \"zarr_shards\",\n", + " \"shape\",\n", + " \"on_disk_chunks\",\n", + " \"workers\",\n", + " \"threads_per_worker\",\n", + " \"processes\",\n", + " \"workload\",\n", + "]\n", + "\n", + "baseline = summary.loc[\n", + " summary[\"dask_chunks_arg\"] == \"default\", baseline_cols + [\"elapsed_median_s\"]\n", + "]\n", + "baseline = baseline.rename(columns={\"elapsed_median_s\": \"default_elapsed_median_s\"})\n", + "speedups = summary.merge(baseline, on=baseline_cols, how=\"left\")\n", + "speedups[\"speedup_vs_default\"] = (\n", + " speedups[\"default_elapsed_median_s\"] / speedups[\"elapsed_median_s\"]\n", + ")\n", + "speedups.sort_values([\"workload\", \"speedup_vs_default\"], ascending=[True, False]).head(\n", + " 30\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10185d26023b46108eb7d9f57d49d2b3", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "for workload, workload_df in summary.groupby(\"workload\"):\n", + " fig, ax = plt.subplots(figsize=(10, 5))\n", + " labels = []\n", + " for label, label_df in workload_df.groupby([\n", + " \"store_type\",\n", + " \"on_disk_chunks\",\n", + " \"workers\",\n", + " \"threads_per_worker\",\n", + " ]):\n", + " label_df = label_df.sort_values(\"dask_chunks_arg\")\n", + " labels.append(str(label))\n", + " ax.plot(\n", + 
" label_df[\"dask_chunks_arg\"],\n", + " label_df[\"elapsed_median_s\"],\n", + " marker=\"o\",\n", + " label=str(label),\n", + " )\n", + " ax.set_title(workload)\n", + " ax.set_xlabel(\"Dask chunks argument\")\n", + " ax.set_ylabel(\"Median elapsed seconds\")\n", + " ax.tick_params(axis=\"x\", rotation=45)\n", + " ax.legend(fontsize=\"small\", bbox_to_anchor=(1.05, 1), loc=\"upper left\")\n", + " fig.tight_layout()\n", + " plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/benchmarks/scripts/dask_chunk_grid.py b/benchmarks/scripts/dask_chunk_grid.py new file mode 100644 index 000000000..02ca58496 --- /dev/null +++ b/benchmarks/scripts/dask_chunk_grid.py @@ -0,0 +1,667 @@ +from __future__ import annotations + +import argparse +import csv +import itertools +import json +import platform +import shutil +import time +from contextlib import contextmanager +from dataclasses import asdict, dataclass +from datetime import UTC, datetime +from importlib.metadata import PackageNotFoundError, version +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +import h5py +import numpy as np +import zarr + +import anndata as ad +from anndata.experimental import read_elem_lazy + +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator, Sequence + +ChunkSpec = tuple[int | None, ...] +StoreType = Literal["h5ad", "zarr"] + + +DEFAULT_ON_DISK_CHUNKS: tuple[ChunkSpec, ...] = ((256, 1024), (1024, 1024)) +DEFAULT_DASK_CHUNKS: tuple[ChunkSpec | None, ...] 
@dataclass(frozen=True)
class BenchmarkResult:
    """One CSV row: a single timed run of one workload at one grid point.

    The field order is significant: ``write_results`` uses the dataclass
    field order as the CSV column order.
    """

    # Timestamp shared by every row of one invocation (set once in ``main``).
    run_started_at: str
    # Environment description from ``runtime_metadata``; a package that is
    # not installed is recorded as "not-installed".
    python_version: str
    platform: str
    anndata_version: str
    numpy_version: str
    h5py_version: str
    zarr_version: str
    dask_version: str
    distributed_version: str
    scanpy_version: str
    # Storage backend of the benchmarked store ("h5ad" or "zarr").
    store_type: str
    # Zarr format number for Zarr stores, the string "n/a" for HDF5 stores.
    zarr_format: int | str
    # Requested shard shape rendered by ``format_chunks``; "default" when no
    # shards were requested.
    zarr_shards: str
    # Requested array shape rendered by ``format_chunks`` (e.g. "100x20").
    shape: str
    # Size of the dense float32 array implied by ``shape``, in bytes.
    dataset_nbytes: int
    # Size of the store on disk as measured by ``path_nbytes``, in bytes.
    store_nbytes: int
    # On-disk chunks as requested (before clamping to the array shape).
    on_disk_chunks: str
    # Chunks passed to ``read_elem_lazy``; "default" when None was passed.
    dask_chunks_arg: str
    # ``chunksize`` of the resulting lazy array, i.e. the chunks actually used.
    dask_chunksize: str
    # Dask cluster settings for this run.
    workers: int
    threads_per_worker: int
    processes: bool
    # Workload name and repeat index (taken from ``range(repeats)``).
    workload: str
    repeat: int
    # Number of entries in the workload's Dask graph.
    task_count: int
    # Wall-clock seconds spent in ``.compute()``.
    elapsed_s: float
    # Shape and byte size of the computed result.
    result_shape: str
    result_nbytes: int
    # Peak RSS of the driver process from ``max_rss_mb``; None when the
    # ``resource`` module is unavailable.
    driver_max_rss_mb: float | None
    # Combined worker RSS sampled right before and right after the compute;
    # None when no worker could report a value.
    worker_rss_before_mb: float | None
    worker_rss_after_mb: float | None
f"Chunk dimension {part!r} is not an integer" + raise argparse.ArgumentTypeError(msg) from exc + if dim == 0 or dim < -1: + msg = f"Chunk dimension must be positive, -1, or None, got {dim}" + raise argparse.ArgumentTypeError(msg) + parsed.append(dim) + return tuple(parsed) + + +def parse_shape(value: str) -> tuple[int, int]: + parsed = parse_chunk_spec(value) + if parsed is None or any(dim is None or dim <= 0 for dim in parsed): + msg = f"Shape must contain two positive integers, got {value!r}" + raise argparse.ArgumentTypeError(msg) + return (int(parsed[0]), int(parsed[1])) + + +def materialize_chunks(chunks: ChunkSpec, shape: tuple[int, int]) -> tuple[int, int]: + return tuple( + axis_size if chunk in {-1, None} else min(int(chunk), axis_size) + for chunk, axis_size in zip(chunks, shape, strict=True) + ) + + +def format_chunks(chunks: Sequence[int | None] | None) -> str: + if chunks is None: + return "default" + return "x".join("none" if chunk is None else str(chunk) for chunk in chunks) + + +def remove_path(path: Path) -> None: + if path.is_dir(): + shutil.rmtree(path) + elif path.exists(): + path.unlink() + + +def generate_counts(shape: tuple[int, int], seed: int) -> np.ndarray: + rng = np.random.default_rng(seed) + return rng.poisson(lam=1.0, size=shape).astype(np.float32, copy=False) + + +def path_nbytes(path: Path) -> int: + if path.is_file(): + return path.stat().st_size + return sum(child.stat().st_size for child in path.rglob("*") if child.is_file()) + + +def package_version(package: str) -> str: + try: + return version(package) + except PackageNotFoundError: + return "not-installed" + + +def runtime_metadata() -> dict[str, str]: + return { + "python_version": platform.python_version(), + "platform": platform.platform(), + "anndata_version": package_version("anndata"), + "numpy_version": package_version("numpy"), + "h5py_version": package_version("h5py"), + "zarr_version": package_version("zarr"), + "dask_version": package_version("dask"), + 
"distributed_version": package_version("distributed"), + "scanpy_version": package_version("scanpy"), + } + + +def store_path(base_dir: Path, config: StoreConfig) -> Path: + chunk_label = format_chunks(config.on_disk_chunks) + if config.store_type == "h5ad": + return base_dir / f"X_h5ad_chunks-{chunk_label}.h5ad" + + shard_label = ( + f"_shards-{format_chunks(config.zarr_shards)}" + if config.zarr_shards is not None + else "" + ) + return ( + base_dir + / f"X_zarr-v{config.zarr_format}_chunks-{chunk_label}{shard_label}.zarr" + ) + + +def write_store( + path: Path, + config: StoreConfig, + shape: tuple[int, int], + seed: int, + *, + force: bool, +) -> None: + if path.exists(): + if not force: + return + remove_path(path) + + data = generate_counts(shape, seed) + on_disk_chunks = materialize_chunks(config.on_disk_chunks, shape) + if config.store_type == "h5ad": + with h5py.File(path, "w") as f: + ad.io.write_elem(f, "X", data, dataset_kwargs={"chunks": on_disk_chunks}) + return + + dataset_kwargs: dict[str, object] = {"chunks": on_disk_chunks} + if config.zarr_shards is not None: + dataset_kwargs["shards"] = materialize_chunks(config.zarr_shards, shape) + + root = zarr.open_group(path, mode="w", zarr_format=config.zarr_format) + with ad.settings.override(auto_shard_zarr_v3=False): + ad.io.write_elem(root, "X", data, dataset_kwargs=dataset_kwargs) + zarr.consolidate_metadata(root.store) + + +@contextmanager +def open_x(path: Path, store_type: StoreType) -> Iterator[object]: + if store_type == "h5ad": + with h5py.File(path, "r") as f: + yield f["X"] + return + + root = zarr.open_group(path, mode="r") + yield root["X"] + + +def max_rss_mb() -> float | None: + try: + import resource + except ImportError: + return None + + rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss + if platform.system() == "Darwin": + return rss / (1024 * 1024) + return rss / 1024 + + +def _process_rss_bytes() -> int | None: + try: + import psutil + except ImportError: + return None + 
def worker_rss_mb(client) -> float | None:
    """Return the combined resident set size of the Dask workers in MiB.

    Each worker reports its process id together with its RSS, and every
    process is counted once. Workers started with ``processes=False`` all
    live in the same process, so summing one reading per worker would
    multiply that process's RSS by the number of workers.

    Parameters
    ----------
    client
        Object with a ``run(function)`` method returning a mapping of worker
        to the function's return value (a ``dask.distributed.Client``).

    Returns
    -------
    Total RSS in MiB over the distinct worker processes, or ``None`` when no
    worker could report a value (``_process_rss_bytes`` returned ``None``).
    """

    def _probe() -> tuple[int, int | None]:
        # Executed on each worker; imported here so the function carries no
        # closure state when it is shipped to worker processes.
        import os

        return os.getpid(), _process_rss_bytes()

    rss_by_pid: dict[int, int] = {}
    for pid, rss in client.run(_probe).values():
        if rss is None:
            continue
        # Readings from one process may differ slightly between workers;
        # keep the largest so the value stays a conservative estimate.
        rss_by_pid[pid] = max(rss, rss_by_pid.get(pid, 0))

    if not rss_by_pid:
        return None
    return sum(rss_by_pid.values()) / (1024 * 1024)
def result_nbytes(result: object) -> int:
    """Return ``result.nbytes`` as an int, or 0 if the attribute is missing."""
    try:
        size = result.nbytes
    except AttributeError:
        size = 0
    return int(size)
def iter_store_configs(args: argparse.Namespace) -> Iterator[StoreConfig]:
    """Yield one ``StoreConfig`` per store to generate.

    Store types are the outer loop and on-disk chunks the inner one. HDF5
    stores ignore ``args.zarr_shards`` and yield a single config per chunk
    spec; Zarr stores yield one config per requested shard spec.

    Raises
    ------
    ValueError
        When a shard spec is requested together with Zarr format 2. The
        error is raised lazily, when iteration reaches that combination.
    """
    for backend in args.store_types:
        for chunks in args.on_disk_chunks:
            if backend != "h5ad":
                for shards in args.zarr_shards:
                    if shards is not None and args.zarr_format == 2:
                        msg = "Zarr v2 does not support shards; omit --zarr-shards or use --zarr-format 3"
                        raise ValueError(msg)
                    yield StoreConfig(
                        store_type="zarr",
                        on_disk_chunks=chunks,
                        zarr_format=args.zarr_format,
                        zarr_shards=shards,
                    )
            else:
                yield StoreConfig(
                    store_type="h5ad",
                    on_disk_chunks=chunks,
                    zarr_format=args.zarr_format,
                    zarr_shards=None,
                )
def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the chunk grid benchmark.

    Repeatable options (``--on-disk-chunks``, ``--dask-chunks``,
    ``--zarr-shards``, ``--workers``, ``--threads-per-worker``) default to
    ``None`` here; ``normalize_args`` replaces ``None`` with the real default
    grid afterwards, so that a value given on the command line replaces the
    defaults instead of being appended to them.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run an exploratory grid over on-disk chunking, Dask chunks, and "
            "Dask worker settings for AnnData lazy array reads."
        )
    )
    parser.add_argument(
        "--work-dir",
        type=Path,
        default=Path("benchmarks/results/dask_chunk_grid_stores"),
        help="Directory for generated HDF5/Zarr stores.",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("benchmarks/results/dask_chunk_grid.csv"),
        help="CSV path for benchmark rows.",
    )
    parser.add_argument(
        "--shape",
        type=parse_shape,
        default=(12_000, 3_000),
        help="Dense X shape as OBS,VARS.",
    )
    parser.add_argument(
        "--store-types",
        nargs="+",
        choices=("h5ad", "zarr"),
        default=["h5ad", "zarr"],
        help="Storage backends to benchmark.",
    )
    parser.add_argument(
        "--on-disk-chunks",
        action="append",
        type=parse_chunk_spec,
        default=None,
        help="On-disk chunks as OBS,VARS. Repeat for a grid.",
    )
    parser.add_argument(
        "--dask-chunks",
        action="append",
        type=parse_chunk_spec,
        default=None,
        help="Virtual Dask chunks as OBS,VARS, or 'default'. Repeat for a grid.",
    )
    parser.add_argument(
        "--zarr-format",
        type=int,
        choices=(2, 3),
        default=3,
        help="Zarr format for generated Zarr stores.",
    )
    parser.add_argument(
        "--zarr-shards",
        action="append",
        type=parse_chunk_spec,
        default=None,
        help="Zarr v3 shard shape as OBS,VARS. Repeat for a grid.",
    )
    parser.add_argument(
        "--workers",
        action="append",
        type=int,
        default=None,
        help="Dask worker counts to test. Repeat for a grid.",
    )
    parser.add_argument(
        "--threads-per-worker",
        action="append",
        type=int,
        default=None,
        help="Dask threads per worker to test. Repeat for a grid.",
    )
    parser.add_argument(
        "--no-processes",
        action="store_false",
        dest="processes",
        help="Use threaded workers instead of process workers.",
    )
    parser.set_defaults(processes=True)
    parser.add_argument(
        "--workloads",
        nargs="+",
        choices=(
            "sum_axis0",
            "sum_axis1",
            "subset_mean",
            "normalize_log1p_slice",
            "scanpy_normalize_log1p",
        ),
        default=list(DEFAULT_WORKLOADS),
        help="Dask workloads to run.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=3,
        help="Number of times each grid point is run (default: %(default)s).",
    )
    parser.add_argument(
        "--slice-obs",
        type=int,
        default=4096,
        help=(
            "Leading observations used by the subset_mean and "
            "normalize_log1p_slice workloads (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--slice-vars",
        type=int,
        default=1024,
        help=(
            "Leading variables used by the subset_mean, normalize_log1p_slice, "
            "and scanpy_normalize_log1p workloads (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--target-sum",
        type=float,
        default=10_000.0,
        help="Per-observation total used when normalizing (default: %(default)s).",
    )
    parser.add_argument(
        "--memory-limit",
        default="auto",
        help="memory_limit passed to the Dask LocalCluster (default: %(default)s).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Seed for the generated count matrix (default: %(default)s).",
    )
    parser.add_argument(
        "--append",
        action="store_true",
        help="Append rows to --output instead of overwriting it.",
    )
    parser.add_argument("--force", action="store_true", help="Regenerate stores.")
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print the expanded grid without generating data or running Dask.",
    )
    return parser
def main(argv: Sequence[str] | None = None) -> int:
    """Parse arguments, expand the grid, run it, and write the CSV.

    All argument problems are reported through ``parser.error`` (usage
    message and exit status 2) before any store is generated or the output
    file is opened.

    Parameters
    ----------
    argv
        Argument list; ``None`` makes argparse read ``sys.argv``.

    Returns
    -------
    Process exit status, always 0 on success.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    normalize_args(args)

    if args.repeats < 1:
        parser.error("--repeats must be at least 1")
    if any(worker < 1 for worker in args.workers):
        parser.error("--workers values must be at least 1")
    if any(threads < 1 for threads in args.threads_per_worker):
        parser.error("--threads-per-worker values must be at least 1")
    # parse_chunk_spec turns "default"/"none"/"null" into None. That is a
    # valid --dask-chunks or --zarr-shards value, but an on-disk layout must
    # be concrete: materialize_chunks cannot iterate over None.
    if any(chunks is None for chunks in args.on_disk_chunks):
        parser.error("--on-disk-chunks must be given as OBS,VARS, not 'default'")
    # Zero would make the slice workloads reduce over an empty selection and
    # a negative value would silently slice from the end of the axis.
    if args.slice_obs < 1 or args.slice_vars < 1:
        parser.error("--slice-obs and --slice-vars must be at least 1")

    try:
        store_configs = list(iter_store_configs(args))
    except ValueError as exc:
        # Zarr v2 combined with shards: report as a usage error, not a traceback.
        parser.error(str(exc))
    dask_configs = list(iter_dask_configs(args))

    if args.dry_run:
        print(json.dumps(describe_plan(args), indent=2, default=format_chunks))
        return 0

    args.work_dir.mkdir(parents=True, exist_ok=True)
    workload_config = WorkloadConfig(
        shape=args.shape,
        slice_obs=args.slice_obs,
        slice_vars=args.slice_vars,
        target_sum=args.target_sum,
        memory_limit=args.memory_limit,
        dataset_nbytes=int(
            np.prod(args.shape, dtype=np.int64) * np.dtype("float32").itemsize
        ),
        run_started_at=datetime.now(UTC).isoformat(timespec="seconds"),
    )

    def rows() -> Iterator[BenchmarkResult]:
        # Lazily generated so write_results can flush each row as soon as
        # its benchmark finishes.
        for store_config in store_configs:
            path = store_path(args.work_dir, store_config)
            write_store(path, store_config, args.shape, args.seed, force=args.force)
            for dask_config, workload, repeat in itertools.product(
                dask_configs, args.workloads, range(args.repeats)
            ):
                yield run_one(
                    path,
                    store_config,
                    dask_config,
                    workload=workload,
                    repeat=repeat,
                    workload_config=workload_config,
                )

    write_results(args.output, rows(), append=args.append)
    return 0
def load_script_module():
    """Import ``benchmarks/scripts/dask_chunk_grid.py`` as ``dask_chunk_grid``.

    The script is not part of an importable package, so it is loaded from its
    file path. Every call executes the script again and replaces the
    ``sys.modules`` entry. If execution fails, the entry is removed again so
    no half-initialised module is left behind for later tests.
    """
    script = Path(__file__).parents[1] / "benchmarks" / "scripts" / "dask_chunk_grid.py"
    spec = importlib.util.spec_from_file_location("dask_chunk_grid", script)
    # Check before first use so a failure reads as a clear assertion.
    assert spec is not None
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Registered before execution, as in the importlib recipe.
    # NOTE(review): the script's dataclasses use postponed annotations and
    # presumably need the module to be present in sys.modules while the
    # classes are created - confirm before reordering.
    sys.modules[spec.name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(spec.name, None)
        raise
    return module
"100x20" + assert plan["dataset_nbytes"] == 8000