diff --git a/docs/release-notes/parallel-h5-read.feat.md b/docs/release-notes/parallel-h5-read.feat.md new file mode 100644 index 000000000..df7178d32 --- /dev/null +++ b/docs/release-notes/parallel-h5-read.feat.md @@ -0,0 +1,23 @@ + +Multi-threaded HDF5 reads in ``read_h5ad``. Chunked datasets +compressed with **gzip / deflate**, **zstd**, or **blosc** (any inner +codec — lz4, zstd, blosclz, snappy) are now read into memory in +parallel: chunk locations are enumerated via +``Dataset.id.chunk_iter``, adjacent chunks are coalesced into on-disk +extents, and worker threads each ``pread()`` a sub-extent (bypassing +h5py's process-global lock) and decompress their chunks into disjoint +regions of the output array using GIL-releasing decoders from +:mod:`numcodecs`. + +Bit-identical to the serial path. Falls back to serial silently for +ineligible datasets (unsupported or multi-filter pipeline, unchunked, +partially written, smaller than ``parallel_h5_read_min_mb``), and with +a ``RuntimeWarning`` for a non-disk file driver, missing ``chunk_iter`` +(h5py < 3.8), or any unexpected error. + +Controlled by ``settings.parallel_h5_read`` (default ``True``), +``settings.parallel_h5_read_min_mb`` (default ``64``), and +``settings.parallel_h5_read_workers`` (default ``min(cpu_count, 16)``).
diff --git a/pyproject.toml b/pyproject.toml index ea1c21a07..2cc8e57f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "array-api-compat>=1.7.1", "legacy-api-wrap", "zarr >=3.1", + "numcodecs>=0.14", "typing-extensions; python_version<'3.13'", "scverse-misc[settings]>=0.0.7", ] @@ -105,6 +106,7 @@ test-min = [ "awkward>=2.6.3", "pyarrow", "pooch", + "hdf5plugin", # registers HDF5 zstd/blosc filters for parallel-read tests "anndata[dask]", ] test-array-api = [ diff --git a/src/anndata/_core/sparse_dataset.py b/src/anndata/_core/sparse_dataset.py index d5b481d72..262b67f27 100644 --- a/src/anndata/_core/sparse_dataset.py +++ b/src/anndata/_core/sparse_dataset.py @@ -330,6 +330,24 @@ def is_sparse_indexing_overridden( ) +def _read_1d_full(arr: h5py.Dataset | zarr.Array) -> np.ndarray: + """Materialise a 1-D backed array into memory. + + For ``h5py.Dataset`` we attempt the parallel chunk-decompression path + (see :mod:`anndata._io.h5_parallel`); it silently returns ``None`` when + the dataset is ineligible (unsupported codec, too small, unchunked, + multi-dimensional, ...), in which case — and for zarr arrays — we fall + through to the existing serial read. Bit-identical to ``arr[...]``. + """ + if isinstance(arr, h5py.Dataset): + from .._io.h5_parallel import parallel_read_full + + out = parallel_read_full(arr) + if out is not None: + return out + return arr[...] + + class BaseCompressedSparseDataset[GroupT: _GroupStorageType, ArrayT: _ArrayStorageType]( abc._AbstractCSDataset, ABC ): @@ -580,8 +598,8 @@ def to_memory(self) -> SparseMatrixType: shape=self.shape, ) mtx = backed_class.memory_format(self.shape, dtype=self.dtype) - mtx.data = self._data[...] - mtx.indices = self._indices[...] 
+ mtx.data = _read_1d_full(self._data) + mtx.indices = _read_1d_full(self._indices) mtx.indptr = self._indptr return mtx diff --git a/src/anndata/_io/_pool.py b/src/anndata/_io/_pool.py new file mode 100644 index 000000000..c86259cc8 --- /dev/null +++ b/src/anndata/_io/_pool.py @@ -0,0 +1,54 @@ +"""Generic, process-global thread-pool management for IO-side parallelism.""" + +from __future__ import annotations + +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from concurrent.futures import Future + from typing import TypeVar + + R = TypeVar("R") + + +class PoolManager: + def __init__(self, *, thread_name_prefix: str) -> None: + self._lock = threading.Lock() + self._pool: ThreadPoolExecutor | None = None + self._n_workers: int = 0 + self._thread_name_prefix = thread_name_prefix + + def distribute_tasks_to_threads( + self, + n_workers: int, + fn: Callable[..., R], + tasks: Sequence[tuple], + ) -> list[Future[R]]: + """Submit one ``fn(*t)`` call per ``t`` of ``tasks`` to a + pool of ``n_workers`` threads, returning the resulting Futures in + input order. 
+ """ + with self._lock: + self._ensure_pool_locked(n_workers) + assert self._pool is not None # for type checkers + return [self._pool.submit(fn, *task) for task in tasks] + + def reset(self) -> None: + with self._lock: + if self._pool is not None: + self._pool.shutdown(wait=True) + self._pool = None + self._n_workers = 0 + + def _ensure_pool_locked(self, n_workers: int) -> None: + if self._pool is not None and self._n_workers == n_workers: + return + if self._pool is not None: + self._pool.shutdown(wait=False, cancel_futures=False) + self._pool = ThreadPoolExecutor( + max_workers=n_workers, thread_name_prefix=self._thread_name_prefix + ) + self._n_workers = n_workers diff --git a/src/anndata/_io/h5_parallel.py b/src/anndata/_io/h5_parallel.py new file mode 100644 index 000000000..15dd79a70 --- /dev/null +++ b/src/anndata/_io/h5_parallel.py @@ -0,0 +1,452 @@ +"""Parallel HDF5 chunk decompression for whole-array reads. + +By default, HDF5 reads from a single process will use only a single +thread (due to HDF5's design for multi-processing, and h5py's PHIL). +This leads to decompression of all data serializing onto one core. + +When eligible (chunked, supported codec, large enough), this module +uses multi-threading to read HDF5 data into memory, in 3 steps: + +1. **Locate chunks on disk** — get every chunk's on-disk byte + offset/size in a single C-side iteration. +2. **Bulk-pread chunk bytes into memory** — spin up Python threads + and evenly distribute the bytes to be read (os.pread()) to the threads. + This bypasses h5py's process-global lock (PHIL). +3. **Decompress + place** — each worker walks its chunks, decompresses + them with the relevant GIL-releasing Python decoder, and writes each + decoded chunk into its destination slice of the output array. + +Note that numcodecs (Zarr's codec package) is reused here; any codec from +that package can be added here with a few lines of code. + +Falls back to the caller's serial path (with a warning) if any error occurs.
+""" + +from __future__ import annotations + +import math +import os +import threading +from concurrent.futures import wait +from itertools import batched, pairwise +from typing import TYPE_CHECKING, NamedTuple + +import numcodecs +import numcodecs.blosc +import numpy as np + +from anndata._settings import settings +from anndata.utils import warn_once + +from ._pool import PoolManager + +if TYPE_CHECKING: + from collections.abc import Callable + + import h5py + +_pool_manager = PoolManager(thread_name_prefix="anndata-h5-parallel") + +# HDF5 filter IDs +# https://support.hdfgroup.org/documentation/hdf5/latest/group___h5_z.html +# https://github.com/HDFGroup/hdf5_plugins/blob/master/docs/RegisteredFilterPlugins.md +_FILTER_DEFLATE = 1 +_FILTER_BLOSC = 32001 +_FILTER_ZSTD = 32015 + +# HDF5 generally has several-KB gaps between some of its on-disk chunks. +# These primarily come from the HDF5 index, e.g. in the default case of +# its B-tree, index nodes are written between occasional chunks. +# So for read efficiency, we read these extra bytes +# (which we won't use), because it's fastest. +# Note that gaps can also occur from unusual HDF5 file writing patterns, so +# in the case of gaps larger than 256KB, we simply do separate reads excluding +# those gaps to avoid reading excessive unused bytes. +_EXTENT_ALLOWED_CHUNK_GAPS_BYTES = 256 * 1024 + +# file drivers we can pread() against +# - 'sec2' is the default POSIX driver on Unix. +# - 'stdio' goes through fopen() but still backs a single regular file. +_DISK_BACKED_DRIVERS: frozenset[str] = frozenset({"sec2", "stdio"}) + + +class _UnsupportedDriver(RuntimeError): + """Raised when the h5py file uses a non-disk driver (e.g., 'core', + 'mpio', 'ros3', 'fileobj') for which our bulk ``pread()`` approach + isn't applicable. Caller falls back to serial.""" + + +class _ChunkIterUnavailable(RuntimeError): + """Raised when ``Dataset.id.chunk_iter`` is not present (older h5py + < 3.8 / older HDF5 builds).
Recommended to update for multi-threading + (significant performance improvements).""" + + +# ------------------------------------------------------------------------------- +# Codec dispatch (HDF5 filter id -> GIL-releasing decoder) +# ------------------------------------------------------------------------------- + + +_FILTER_TO_DECODER: dict[int, Callable[[bytes], bytes]] = { + _FILTER_DEFLATE: numcodecs.Zlib().decode, # Note that HDF5 'gzip' is actually zlib + _FILTER_BLOSC: numcodecs.Blosc().decode, # inner codec self-described in payload + _FILTER_ZSTD: numcodecs.Zstd().decode, +} + +_blosc_threads_lock = threading.Lock() +_blosc_threads_configured = False + + +def _ensure_blosc_single_threaded() -> None: + global _blosc_threads_configured + if _blosc_threads_configured: + return + with _blosc_threads_lock: + if _blosc_threads_configured: + return + numcodecs.blosc.set_nthreads(1) + _blosc_threads_configured = True + + +# ------------------------------------------------------------------------------- +# Eligibility gating +# ------------------------------------------------------------------------------- + + +def _supported_filter(dset: h5py.Dataset) -> Callable[[bytes], bytes] | None: + # AnnData does not generally deal with HDF5 files with multiple filter + # steps (e.g. shuffle + compressor). Single-filter pipelines only. 
+ plist = dset.id.get_create_plist() + if plist.get_nfilters() != 1: + return None + filter_id = plist.get_filter(0)[0] + decoder = _FILTER_TO_DECODER.get(filter_id) + if decoder is None: + return None + if filter_id == _FILTER_BLOSC: + _ensure_blosc_single_threaded() + return decoder + + +def _eligible_for_multithreading( + dset: h5py.Dataset, *, min_mb: float +) -> tuple[Callable[[bytes], bytes], int] | None: + """Return ``(decoder, n_chunks)`` if ``dset`` can take the parallel path, + else ``None`` so the caller uses its serial read.""" + # Nothing to parallelize if no chunks + if dset.chunks is None: + return None + n_chunks = dset.id.get_num_chunks() + if n_chunks < 2: + return None + + # np.frombuffer cannot create object arrays from raw bytes, so object + # dtypes (e.g. variable-length strings) must take the serial path. + if dset.dtype.hasobject: + return None + + # Reject datasets with initialized but unfilled chunks + expected = math.prod( + math.ceil(s / c) for s, c in zip(dset.shape, dset.chunks, strict=True) + ) + if n_chunks != expected: + return None + + # math.prod works on Python ints, so the element count cannot overflow. + nbytes = math.prod(dset.shape) * dset.dtype.itemsize + if nbytes < min_mb * (1 << 20): + return None + + # Codec must be supported + decoder = _supported_filter(dset) + if decoder is None: + return None + return decoder, n_chunks + + +# ------------------------------------------------------------------------------- +# Step 1 — locate chunks on disk +# ------------------------------------------------------------------------------- + + +class _Extent(NamedTuple): + """A contiguous byte range on disk plus the chunks it covers. + + Produced first by adjacency-coalescing all chunks (one extent per + contiguous on-disk region), then refined by splitting each extent + across workers (one ``pread()`` per resulting sub-extent). + ``members`` are ``StoreInfo`` namedtuples from h5py's ``chunk_iter`` + callback (fields: ``chunk_offset``, ``filter_mask``, + ``byte_offset``, ``size``), and must be sorted by ``byte_offset`` + and non-overlapping — both hold for HDF5 dataset chunks by + construction.
+ """ + + file_start: int # absolute byte offset in the file + file_end: int # exclusive + members: list # list[StoreInfo] + + @property + def length(self) -> int: + return self.file_end - self.file_start + + @classmethod + def from_members(cls, members) -> _Extent: + members = list(members) # batched() yields tuples; normalize + return cls( + file_start=members[0].byte_offset, + file_end=members[-1].byte_offset + members[-1].size, + members=members, + ) + + +def _gather_chunk_metadata(dset: h5py.Dataset) -> list: + """Enumerate every chunk's metadata in a single C-side iteration.""" + if not hasattr(dset.id, "chunk_iter"): + msg = ( + "h5py.Dataset.id.chunk_iter not available; need h5py >= 3.8; " + "upgrading is strongly recommended for better performance." + ) + raise _ChunkIterUnavailable(msg) + metas: list = [] + dset.id.chunk_iter(metas.append) + return metas + + +def _coalesce_adjacent_chunks(metas: list) -> list[_Extent]: + """\ + This determines the contiguous on-disk byte ranges to be read (extents). + + This will generally return one extent unless the HDF5 file has gaps + between chunks greater than ``_EXTENT_ALLOWED_CHUNK_GAPS_BYTES``. See + above comment on this parameter.
+ """ + if not metas: + return [] + sorted_metas = sorted(metas, key=lambda m: m.byte_offset) + splits = [ + i + for i, (prev, cur) in enumerate(pairwise(sorted_metas), start=1) + if cur.byte_offset - (prev.byte_offset + prev.size) + > _EXTENT_ALLOWED_CHUNK_GAPS_BYTES + ] + boundaries = [0, *splits, len(sorted_metas)] + return [_Extent.from_members(sorted_metas[a:b]) for a, b in pairwise(boundaries)] + + +def _split_extents_for_workers( + extents: list[_Extent], n_workers: int, n_chunks: int +) -> list[_Extent]: + step = max(1, math.ceil(n_chunks / n_workers)) + return [ + _Extent.from_members(sub) + for extent in extents + for sub in batched(extent.members, step) + ] + + +def _assert_disjoint_chunks(sub_extents: list[_Extent], n_chunks: int) -> None: + """Refuse to dispatch unless every chunk appears in exactly one sub-extent; + otherwise workers would race on the output array.""" + seen: set[tuple[int, ...]] = { + m.chunk_offset for ext in sub_extents for m in ext.members + } + if len(seen) != n_chunks: + msg = ( + f"parallel read partition is invalid: {len(seen)} unique chunks, " + f"expected {n_chunks}." + ) + raise RuntimeError(msg) + + +# ------------------------------------------------------------------------------- +# Step 2 — bulk pread chunk bytes into memory +# ------------------------------------------------------------------------------- + + +def _open_for_bulk_read(dset: h5py.Dataset) -> int: + """Open a raw read-only file descriptor at the path backing ``dset``. + + This is separate from h5py's fd because h5py routes + I/O through libhdf5's VFD layer; here we want a read that bypasses + h5py's PHIL. Multiple read-only fds on the same file + are universally supported on POSIX and Windows. + + Falls back via ``_UnsupportedDriver`` if the file isn't on a normal + disk-backed driver (e.g., in-memory, MPI-IO, S3-backed). 
+ """ + driver = dset.file.driver + if driver not in _DISK_BACKED_DRIVERS: + msg = f"file driver {driver!r} is not supported by the bulk-read path" + raise _UnsupportedDriver(msg) + # O_BINARY exists only on Windows, where the default text mode would + # translate line endings and corrupt the raw chunk bytes. + fd = os.open(dset.file.filename, os.O_RDONLY | getattr(os, "O_BINARY", 0)) + + # On linux, speed up cold cache (start populating page cache) + if hasattr(os, "posix_fadvise"): + try: + os.posix_fadvise(fd, 0, 0, os.POSIX_FADV_WILLNEED) + except OSError: # pragma: no cover - some FSes refuse fadvise + pass + return fd + + +# Serialises the lseek+read fallback: worker threads share one fd, and a seek +# followed by a read is not atomic the way ``os.pread`` is. +_pread_fallback_lock = threading.Lock() + + +# TODO: before completion of PR, adjust for proper Windows support +def _pread(fd: int, n: int, offset: int) -> bytes: + """Cross-platform positional read of ``n`` bytes starting at ``offset``. + + A single ``os.pread`` may return fewer bytes than requested (Linux caps one + call at just under 2 GiB), so keep reading until ``n`` bytes are collected + or EOF is reached; a result shorter than ``n`` therefore means EOF. + ``os.pread`` is Unix-only; elsewhere fall back to a locked lseek+read. + """ + parts: list[bytes] = [] + remaining = n + while remaining > 0: + if hasattr(os, "pread"): + part = os.pread(fd, remaining, offset) + else: # pragma: no cover - Unix in CI + with _pread_fallback_lock: + os.lseek(fd, offset, os.SEEK_SET) + part = os.read(fd, remaining) + if not part: # EOF + break + parts.append(part) + offset += len(part) + remaining -= len(part) + return b"".join(parts) + + +# ------------------------------------------------------------------------------- +# Step 3 — decompress + place into output +# ------------------------------------------------------------------------------- + + +def _chunk_assignment( + chunk_offset: tuple[int, ...], + chunk_shape: tuple[int, ...], + array_shape: tuple[int, ...], +) -> tuple[tuple[slice, ...], tuple[slice, ...]]: + """Chunked HDF5 is designed such that N-dimensional arrays are split into N-dimensional + chunks, all of the same size (e.g. tiles). + + Thus, chunks around the edges typically contain extra, unused bytes. + + Here we identify how much of the chunk to keep, and the position of the N-dimensional + output array where it should reside. This results in the chunks being properly + re-assembled into the original array.
+ """ + array_region: list[slice] = [] + chunk_region: list[slice] = [] + for chunk_start, chunk__len, array_axis_len in zip( + chunk_offset, chunk_shape, array_shape, strict=True + ): + chunk_boundary = min(chunk__len, array_axis_len - chunk_start) + array_region.append(slice(chunk_start, chunk_start + chunk_boundary)) + chunk_region.append(slice(0, chunk_boundary)) + return tuple(array_region), tuple(chunk_region) + + +def _process_extent( + extent: _Extent, + fd: int, + decoder: Callable[[bytes], bytes], + dtype: np.dtype, + chunk_shape: tuple[int, ...], + array_shape: tuple[int, ...], + chunk_elems: int, + out: np.ndarray, +) -> None: + """Read one extent with a single ``pread`` and assign each of its + chunks into the corresponding region of ``out``. + + Workers run concurrently against the same ``out``; disjointness of + their writes is enforced upstream by :func:`_assert_disjoint_chunks`. + """ + extent_bytes = _pread(fd, extent.length, extent.file_start) + if len(extent_bytes) != extent.length: + msg = ( + f"short pread: got {len(extent_bytes)} bytes, " + f"expected {extent.length} at offset {extent.file_start}" + ) + raise OSError(msg) + + extent_view = memoryview(extent_bytes) + for chunk_info in extent.members: + rel_offset = chunk_info.byte_offset - extent.file_start + raw_chunk = extent_view[rel_offset : rel_offset + chunk_info.size] + + # Decode chunk. Skips if HDF5 did not code this chunk. 
+ decoded = raw_chunk if chunk_info.filter_mask else decoder(raw_chunk) + + # Restore it to it's original shape (C row-major order, same as numpy) + chunk_array = np.frombuffer(decoded, dtype=dtype, count=chunk_elems).reshape( + chunk_shape + ) + + # Put it in the correct position of 'out' (noted in the chunk's metadata) + array_region, chunk_region = _chunk_assignment( + chunk_info.chunk_offset, chunk_shape, array_shape + ) + out[array_region] = chunk_array[chunk_region] + + +# ------------------------------------------------------------------------------- +# Orchestration / public entry +# ------------------------------------------------------------------------------- + + +def _resolve_num_workers(workers: int | None) -> int: + if workers is not None: + return workers + return min(os.cpu_count() or 1, 16) + + +def _warn_fallback(exc: Exception) -> None: + msg = ( + "anndata: parallel HDF5 read failed; falling back to serial. " + f"Reason: {type(exc).__name__}: {exc}" + ) + warn_once(msg, RuntimeWarning) + + +def _do_parallel_read( + dset: h5py.Dataset, + decoder: Callable[[bytes], bytes], + n_chunks: int, + n_workers: int, +) -> np.ndarray: + metas = _gather_chunk_metadata(dset) + extents = _coalesce_adjacent_chunks(metas) + pread_extents = _split_extents_for_workers(extents, n_workers, n_chunks) + _assert_disjoint_chunks(pread_extents, n_chunks) + + # worker threads will share the same fd; ``os.pread`` is thread-safe. 
+ fd = _open_for_bulk_read(dset) + try: + # Allocate inside the ``try`` so the fd is still closed if building + # ``out`` raises (e.g. ``MemoryError`` for a very large array). + array_shape = dset.shape + chunk_shape = dset.chunks + dtype = dset.dtype + chunk_elems = int(np.prod(chunk_shape)) + out = np.empty(array_shape, dtype=dtype) + + futs = _pool_manager.distribute_tasks_to_threads( + n_workers, + _process_extent, + [ + (extent, fd, decoder, dtype, chunk_shape, array_shape, chunk_elems, out) + for extent in pread_extents + ], + ) + + wait(futs) + for f in futs: + f.result() + finally: + os.close(fd) + return out + + +def parallel_read_full(dset: h5py.Dataset) -> np.ndarray | None: + """Read a whole ``h5py.Dataset`` into memory using bulk-pread + parallel + decompression if eligible, else return ``None``. + + Bit-identical to ``dset[...]`` on success. Returns ``None`` (so callers + can fall back to their serial path) when: + + * ``settings.parallel_h5_read`` is disabled, + * the dataset is ineligible (see :func:`_eligible_for_multithreading`), + * the file driver isn't disk-backed (we need a real fd for ``pread``), + * ``chunk_iter`` isn't available (h5py < 3.8 / older HDF5 build), + * or any unexpected error occurs during the parallel read.
+ """ + if not settings.parallel_h5_read: + return None + try: + info = _eligible_for_multithreading( + dset, min_mb=settings.parallel_h5_read_min_mb + ) + if info is None: + return None + decoder, n_chunks = info + workers = _resolve_num_workers(settings.parallel_h5_read_workers) + return _do_parallel_read(dset, decoder, n_chunks, workers) + except Exception as exc: # noqa: BLE001 + _warn_fallback(exc) + return None diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index fef6980a4..79110d1d5 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -600,6 +600,12 @@ def write_basic_dask_dask_dense( @_REGISTRY.register_read(ZarrArray, IOSpec("array", "0.2.0")) @_REGISTRY.register_read(ZarrArray, IOSpec("string-array", "0.2.0")) def read_array(elem: _ArrayStorageType, *, _reader: Reader) -> npt.NDArray: + if isinstance(elem, H5Array): + from anndata._io.h5_parallel import parallel_read_full + + out = parallel_read_full(elem) + if out is not None: + return out return elem[()] diff --git a/src/anndata/_settings.py b/src/anndata/_settings.py index 44ee7d07f..b8e81cc4f 100644 --- a/src/anndata/_settings.py +++ b/src/anndata/_settings.py @@ -51,5 +51,31 @@ class Settings( "Only integer indices i.e., those caught by :func:`pandas.api.types.is_integer_dtype` will always be converted to strings. """ + parallel_h5_read: bool = True + """ + Enables substantially faster reads of h5ad files via multi-threading. Currently supports + these compression codecs: HDF5's gzip (zlib), zstd, and blosc with any inner codec. Falls + back to the standard single-threaded read silently if the h5ad file's compression method + is not supported, and with a warning if an error occurs. Set this parameter to False to restrict h5ad file reads + to a single thread. + """ + + parallel_h5_read_min_mb: Annotated[int, Field(ge=0)] = 64 + """ + Minimum uncompressed array size (MiB) below which the serial HDF5 read + path is used.
Default 64 MiB: avoids thread-pool overhead on the many + small ``obs``/``var`` arrays a normal ``read_h5ad`` traverses; gains from + multi-threaded decompression come from ``X``/``layers`` arrays, which are + typically far above this threshold. + """ + + parallel_h5_read_workers: Annotated[int | None, Field(ge=1)] = None + """ + Thread-pool size for parallel HDF5 reads. ``None`` (default) chooses + ``min(os.cpu_count(), 16)``. For long-running jobs on machines with + more than 16 logical cores and high disk bandwidth, consider raising this + to the logical core count and reducing until performance peaks. + """ + settings = Settings() diff --git a/tests/test_io_h5_parallel.py b/tests/test_io_h5_parallel.py new file mode 100644 index 000000000..f6216eb30 --- /dev/null +++ b/tests/test_io_h5_parallel.py @@ -0,0 +1,561 @@ +"""Tests for the auto-detected parallel HDF5 chunk-decompression read path. + +The parallel path (``anndata._io.h5_parallel``) must be bit-identical to the +serial path on success and silently fall back to serial when ineligible. +""" + +from __future__ import annotations + +import warnings +from math import ceil +from typing import TYPE_CHECKING, NamedTuple + +import h5py +import hdf5plugin +import numcodecs +import numpy as np +import pytest +from scipy import sparse + +import anndata as ad +from anndata._io import h5_parallel as h5p + +if TYPE_CHECKING: + from pathlib import Path + + +# ------------------------------------------------------------------------------- +# Shared fixtures / helpers +# ------------------------------------------------------------------------------- + + +# gzip is built into libhdf5; zstd and blosc come from hdf5plugin filters. 
+CODEC_KWARGS: dict[str, dict] = { + "gzip": dict(compression="gzip"), + "zstd": dict(hdf5plugin.Zstd(clevel=3)), + "blosc": dict( + hdf5plugin.Blosc(cname="lz4", clevel=5, shuffle=hdf5plugin.Blosc.SHUFFLE) + ), +} + + +def _csr( + n_obs: int = 2_000, n_var: int = 1_000, density: float = 0.1 +) -> sparse.csr_matrix: + return sparse.random( + n_obs, + n_var, + density=density, + format="csr", + dtype=np.float32, + random_state=np.random.default_rng(0), + ) + + +def _rechunk_x_with_codec( + path: Path, codec_kwargs: dict, *, n_subchunks: int = 4 +) -> None: + # Only X's 1-D arrays get the filter (tiny obs/var datasets don't support + # every plugin); >= 2 chunks keeps them eligible at min_mb=0. + with h5py.File(path, "r+") as f: + for sub in ("data", "indices"): + arr = f[f"X/{sub}"][...] + del f[f"X/{sub}"] + chunk = max(1, ceil(arr.shape[0] / n_subchunks)) + f["X"].create_dataset(sub, data=arr, chunks=(chunk,), **codec_kwargs) + + +def _spy_parallel(monkeypatch: pytest.MonkeyPatch) -> list[int]: + # Lets a test assert the parallel path ran (else, after a silent fallback, + # a serial-vs-serial comparison would pass vacuously). 
+ calls: list[int] = [] + real = h5p._do_parallel_read + + def spy(*args, **kwargs): + calls.append(1) + return real(*args, **kwargs) + + monkeypatch.setattr(h5p, "_do_parallel_read", spy) + return calls + + +def _assert_csr_byte_equal(a: ad.AnnData, b: ad.AnnData) -> None: + Xa, Xb = a.X.tocsr(), b.X.tocsr() + np.testing.assert_array_equal(Xa.data, Xb.data) + np.testing.assert_array_equal(Xa.indices, Xb.indices) + np.testing.assert_array_equal(Xa.indptr, Xb.indptr) + + +# ------------------------------------------------------------------------------- +# Orchestration / public entry — parallel read is bit-identical to serial +# ------------------------------------------------------------------------------- + + +@pytest.mark.parametrize("codec", list(CODEC_KWARGS), ids=list(CODEC_KWARGS)) +def test_parallel_byte_identical_to_serial( + tmp_path: Path, codec: str, monkeypatch: pytest.MonkeyPatch +) -> None: + p = tmp_path / f"{codec}.h5ad" + ad.AnnData(X=_csr()).write_h5ad(p) + _rechunk_x_with_codec(p, CODEC_KWARGS[codec]) + + with ad.settings.override(parallel_h5_read=False): + serial = ad.read_h5ad(p) + + calls = _spy_parallel(monkeypatch) + with ad.settings.override(parallel_h5_read_min_mb=0): + parallel = ad.read_h5ad(p) + + assert calls, "parallel path did not engage; comparison would be vacuous" + _assert_csr_byte_equal(serial, parallel) + + +def test_dense_2d_default_threshold_engages( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + # 4200 x 4200 f32 ~= 70.6 MB > 64 MiB default: engages with default settings. 
+ X = np.random.default_rng(0).standard_normal((4200, 4200), dtype=np.float32) + p = tmp_path / "dense.h5ad" + ad.AnnData(X=X).write_h5ad(p, compression="gzip") + with h5py.File(p, "r") as f: + assert isinstance(f["X"], h5py.Dataset), "expected dense X" + assert f["X"].chunks is not None + + with ad.settings.override(parallel_h5_read=False): + serial = ad.read_h5ad(p) + + calls = _spy_parallel(monkeypatch) + parallel = ad.read_h5ad(p) + + assert calls, "parallel path did not engage at the default threshold" + np.testing.assert_array_equal(serial.X, parallel.X) + + +def test_csc_gzip_byte_identical( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + X = sparse.random( + 2_000, + 1_000, + density=0.1, + format="csc", + dtype=np.float32, + random_state=np.random.default_rng(0), + ) + p = tmp_path / "csc.h5ad" + ad.AnnData(X=X).write_h5ad(p) + _rechunk_x_with_codec(p, CODEC_KWARGS["gzip"]) + + with ad.settings.override(parallel_h5_read=False): + serial = ad.read_h5ad(p) + + calls = _spy_parallel(monkeypatch) + with ad.settings.override(parallel_h5_read_min_mb=0): + parallel = ad.read_h5ad(p) + + assert calls + Xa, Xb = serial.X.tocsc(), parallel.X.tocsc() + np.testing.assert_array_equal(Xa.data, Xb.data) + np.testing.assert_array_equal(Xa.indices, Xb.indices) + np.testing.assert_array_equal(Xa.indptr, Xb.indptr) + + +def test_opt_out_disables_parallel( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + p = tmp_path / "x.h5ad" + ad.AnnData(X=_csr()).write_h5ad(p) + _rechunk_x_with_codec(p, CODEC_KWARGS["gzip"]) + + calls = _spy_parallel(monkeypatch) + with ad.settings.override(parallel_h5_read=False, parallel_h5_read_min_mb=0): + ad.read_h5ad(p) + assert calls == [] + + +def test_unexpected_error_falls_back( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + p = tmp_path / "x.h5ad" + ad.AnnData(X=_csr()).write_h5ad(p) + _rechunk_x_with_codec(p, CODEC_KWARGS["gzip"]) + + def boom(*args, **kwargs): + msg = "synthetic" + raise 
RuntimeError(msg)

    # NOTE(review): this is the tail of a test whose ``def`` line (and the
    # definition of ``boom``) lie before this excerpt. From here on it swaps
    # ``boom`` in for ``h5p._do_parallel_read``, expects the "parallel HDF5
    # read failed" RuntimeWarning, and compares the result against a read done
    # with ``parallel_h5_read=False``.
    monkeypatch.setattr(h5p, "_do_parallel_read", boom)
    # warn_once installs a process-wide ignore filter; reset so pytest.warns sees it.
    with warnings.catch_warnings():
        warnings.resetwarnings()
        with (
            pytest.warns(RuntimeWarning, match="parallel HDF5 read failed"),
            ad.settings.override(parallel_h5_read_min_mb=0),
        ):
            adata = ad.read_h5ad(p)
    with ad.settings.override(parallel_h5_read=False):
        ref = ad.read_h5ad(p)
    _assert_csr_byte_equal(ref, adata)


def test_resize_between_reads_succeeds(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    # A worker-count change between reads rebuilds the shared pool; both reads
    # must engage and agree (pre-refactor this could silently fall back).
    p = tmp_path / "x.h5ad"
    ad.AnnData(X=_csr()).write_h5ad(p)
    _rechunk_x_with_codec(p, CODEC_KWARGS["gzip"])

    calls = _spy_parallel(monkeypatch)
    # Same file, same threshold override; only the worker count differs (2 vs 4).
    with ad.settings.override(parallel_h5_read_min_mb=0, parallel_h5_read_workers=2):
        first = ad.read_h5ad(p)
    with ad.settings.override(parallel_h5_read_min_mb=0, parallel_h5_read_workers=4):
        second = ad.read_h5ad(p)
    # NOTE(review): this only proves the spy recorded *something* across the
    # two reads; whether it records one entry per read depends on
    # ``_spy_parallel``, which is defined outside this excerpt.
    assert calls, "parallel path did not engage"
    _assert_csr_byte_equal(first, second)


# -------------------------------------------------------------------------------
# Eligibility gating — ineligible datasets fall back to serial
# -------------------------------------------------------------------------------


def test_small_dataset_falls_back(tmp_path: Path) -> None:
    # A 1000 x 500 CSR matrix at 5 % density is read once with parallel reads
    # disabled and once with default settings; both results must match, and
    # the gate itself must return None for X/data at min_mb=64.
    X = sparse.random(
        1_000,
        500,
        density=0.05,
        format="csr",
        dtype=np.float32,
        random_state=np.random.default_rng(0),
    )
    p = tmp_path / "small.h5ad"
    ad.AnnData(X=X).write_h5ad(p, compression="gzip")
    with ad.settings.override(parallel_h5_read=False):
        serial = ad.read_h5ad(p)
    parallel = ad.read_h5ad(p)  # default 64 MiB threshold; too small to engage
    _assert_csr_byte_equal(serial, parallel)
    with h5py.File(p, "r") as f:
        assert h5p._eligible_for_multithreading(f["X/data"], min_mb=64) is None


def test_unchunked_dataset_falls_back(tmp_path: Path) -> None:
    # A dataset created without ``chunks=`` reports ``chunks is None``; the gate
    # must return None for it even with the size threshold disabled (min_mb=0).
    p = tmp_path / "contig.h5"
    with h5py.File(p, "w") as f:
        f.create_dataset("d", data=np.arange(2_000, dtype=np.float32))
        assert f["d"].chunks is None
        assert h5p._eligible_for_multithreading(f["d"], min_mb=0) is None


def test_unsupported_codec_falls_back(tmp_path: Path) -> None:
    # lzf is HDF5-builtin but has no numcodecs equivalent; stands in for any
    # future filter without a matching decoder.
    p = tmp_path / "lzf.h5"
    with h5py.File(p, "w") as f:
        f.create_dataset(
            "d",
            data=np.arange(4_000, dtype=np.float32),
            chunks=(1_000,),
            compression="lzf",
        )
        assert h5p._eligible_for_multithreading(f["d"], min_mb=0) is None


def test_multifilter_pipeline_falls_back(tmp_path: Path) -> None:
    # Two-filter pipeline (gzip + shuffle): we accept exactly one filter.
    p = tmp_path / "gzip_shuffle.h5"
    with h5py.File(p, "w") as f:
        f.create_dataset(
            "d",
            data=np.arange(4_000, dtype=np.float32),
            chunks=(1_000,),
            compression="gzip",
            shuffle=True,
        )
        assert h5p._eligible_for_multithreading(f["d"], min_mb=0) is None


def test_standalone_lz4_falls_back(tmp_path: Path) -> None:
    # HDF5 LZ4's multi-block framing isn't parseable by numcodecs.LZ4, so the
    # gate must reject it and the serial path must read it.
    X = _csr()
    p = tmp_path / "lz4.h5ad"
    ad.AnnData(X=X).write_h5ad(p)
    _rechunk_x_with_codec(p, dict(hdf5plugin.LZ4()))
    with h5py.File(p, "r") as f:
        assert h5p._eligible_for_multithreading(f["X/data"], min_mb=0) is None
    # The regular read must still round-trip data and indices exactly.
    Xp = ad.read_h5ad(p).X.tocsr()
    np.testing.assert_array_equal(Xp.data, X.data)
    np.testing.assert_array_equal(Xp.indices, X.indices)


def test_partial_chunk_allocation_falls_back() -> None:
    # Unallocated chunks would leave garbage in the pre-allocated output. Mocked
    # since h5py's DatasetID is a C type.
    from types import SimpleNamespace
    from unittest.mock import MagicMock

    # Creation property list reporting exactly one filter (deflate).
    plist = MagicMock()
    plist.get_nfilters.return_value = 1
    plist.get_filter.return_value = (h5p._FILTER_DEFLATE, 0, (), b"")
    dset_id = MagicMock()
    dset_id.get_num_chunks.return_value = 9  # one missing: shape+chunks expect 10
    dset_id.get_create_plist.return_value = plist
    # Stand-in exposing only the attributes set here (chunks, shape, dtype, id).
    fake = SimpleNamespace(
        chunks=(100_000,),
        shape=(1_000_000,),
        dtype=np.dtype("float32"),
        id=dset_id,
    )
    assert h5p._eligible_for_multithreading(fake, min_mb=0) is None
    dset_id.get_num_chunks.return_value = 10  # all allocated -> eligible
    assert h5p._eligible_for_multithreading(fake, min_mb=0) is not None


# -------------------------------------------------------------------------------
# Step 1 — locate chunks on disk (coalesce extents, split, disjointness guard)
# -------------------------------------------------------------------------------


class _StoreInfo(NamedTuple):
    # Matches what h5py's chunk_iter delivers, so tests don't depend on its
    # internal StoreInfo class.
    chunk_offset: tuple[int, ...]
    filter_mask: int
    byte_offset: int
    size: int


def _meta(
    offset: tuple[int, ...], byte_off: int, size: int, mask: int = 0
) -> _StoreInfo:
    # Shorthand constructor. Note the argument order differs from the field
    # order: ``mask`` is last (and optional) here but second in ``_StoreInfo``.
    return _StoreInfo(offset, mask, byte_off, size)


def test_coalesce_adjacent_chunks_contiguous() -> None:
    # Ten back-to-back 1000-byte chunks starting at byte 0 must merge into a
    # single extent covering bytes 0..10_000.
    metas = [_meta((i * 100,), i * 1000, 1000) for i in range(10)]
    extents = h5p._coalesce_adjacent_chunks(metas)
    assert len(extents) == 1
    assert extents[0].file_start == 0
    assert extents[0].file_end == 10_000
    assert extents[0].length == 10_000


def test_coalesce_adjacent_chunks_unsorted_input() -> None:
    # Chunks are supplied in descending byte order; they must still merge into
    # one extent spanning 7000..10_000.
    metas = [
        _meta((0,), 9000, 1000),
        _meta((100,), 8000, 1000),
        _meta((200,), 7000, 1000),
    ]
    extents = h5p._coalesce_adjacent_chunks(metas)
    assert len(extents) == 1
    assert extents[0].file_start == 7000
    assert extents[0].file_end == 10_000


def test_coalesce_adjacent_chunks_large_gap_splits() -> None:
    # A gap above _EXTENT_ALLOWED_CHUNK_GAPS_BYTES splits the extent so we don't
    # pread megabytes of free space between fragmented chunks.
    metas = [
        _meta((0,), 0, 1000),
        _meta((100,), h5p._EXTENT_ALLOWED_CHUNK_GAPS_BYTES + 10_000, 1000),
    ]
    extents = h5p._coalesce_adjacent_chunks(metas)
    assert len(extents) == 2


def test_split_extents_for_workers_no_cross_extent_subextents() -> None:
    # Two extents far apart on disk (bytes 0..5000 and 10_000_000..10_005_000),
    # five chunks each; every sub-extent must fall wholly inside exactly one.
    members_a = [_meta((i * 100,), i * 1000, 1000) for i in range(5)]
    members_b = [
        _meta(((i + 5) * 100,), 10_000_000 + i * 1000, 1000) for i in range(5)
    ]
    extents = [
        h5p._Extent(0, 5000, members_a),
        h5p._Extent(10_000_000, 10_005_000, members_b),
    ]
    sub_extents = h5p._split_extents_for_workers(extents, n_workers=4, n_chunks=10)
    for sub in sub_extents:
        in_a = sub.file_start >= 0 and sub.file_end <= 5000
        in_b = sub.file_start >= 10_000_000 and sub.file_end <= 10_005_000
        # XOR: inside one source extent and not the other.
        assert in_a ^ in_b, f"sub-extent {sub} crosses an extent boundary"


def test_assert_disjoint_chunks_happy_partition_passes() -> None:
    # A split of ten contiguous chunks across 4 workers must pass the guard
    # without raising.
    members = [_meta((i * 100,), i * 1000, 1000) for i in range(10)]
    sub_extents = h5p._split_extents_for_workers(
        [h5p._Extent.from_members(members)], n_workers=4, n_chunks=10
    )
    h5p._assert_disjoint_chunks(sub_extents, n_chunks=10)


def test_assert_disjoint_chunks_overlap_raises() -> None:
    # Two sub-extents referencing the same chunk would race on the output.
    overlapping = [
        h5p._Extent.from_members([_meta((100,), 1000, 1000)]),
        h5p._Extent.from_members([_meta((100,), 2000, 1000)]),
    ]
    with pytest.raises(RuntimeError, match="parallel read partition is invalid"):
        h5p._assert_disjoint_chunks(overlapping, n_chunks=2)


def test_assert_disjoint_chunks_missing_chunk_raises() -> None:
    # Fewer chunks than n_chunks would leave part of the output uninitialised.
    truncated = [h5p._Extent.from_members([_meta((100,), 1000, 1000)])]
    with pytest.raises(RuntimeError, match="parallel read partition is invalid"):
        h5p._assert_disjoint_chunks(truncated, n_chunks=2)


def test_disjoint_guard_triggers_serial_fallback(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    # End-to-end: force an invalid partition and check that the read emits the
    # "parallel HDF5 read failed" warning and still matches a read done with
    # parallel_h5_read=False.
    p = tmp_path / "x.h5ad"
    ad.AnnData(X=_csr()).write_h5ad(p)
    _rechunk_x_with_codec(p, CODEC_KWARGS["gzip"])

    def force_overlap(extents, n_workers, n_chunks):
        # Duplicate the first chunk so the disjointness guard fails.
        first = extents[0].members[0]
        return [h5p._Extent.from_members([first, first])]

    monkeypatch.setattr(h5p, "_split_extents_for_workers", force_overlap)

    # Reset warning filters inside the context so pytest.warns can see the
    # warning (same pattern as the fallback test above).
    with warnings.catch_warnings():
        warnings.resetwarnings()
        with (
            pytest.warns(RuntimeWarning, match="parallel HDF5 read failed"),
            ad.settings.override(parallel_h5_read_min_mb=0),
        ):
            adata = ad.read_h5ad(p)
    with ad.settings.override(parallel_h5_read=False):
        ref = ad.read_h5ad(p)
    _assert_csr_byte_equal(ref, adata)


# -------------------------------------------------------------------------------
# Step 2 — bulk pread chunk bytes into memory (disk-backed driver required)
# -------------------------------------------------------------------------------


def test_core_driver_falls_back(tmp_path: Path) -> None:
    # The in-memory 'core' driver has no on-disk path to pread() against.
    p = tmp_path / "for_core.h5"
    data = np.random.default_rng(0).standard_normal(4_000, dtype=np.float32)
    with h5py.File(p, "w") as f:
        f.create_dataset("d", data=data, chunks=(1_000,), compression="gzip")
    # Re-open the same file through the core driver; opening for bulk read
    # must raise _UnsupportedDriver.
    with (
        h5py.File(p, "r", driver="core") as f,
        pytest.raises(h5p._UnsupportedDriver),
    ):
        h5p._open_for_bulk_read(f["d"])


# -------------------------------------------------------------------------------
# Step 3 — decompress + place into output (chunk -> array slice math)
# -------------------------------------------------------------------------------


def test_filter_mask_chunk_bypassed(tmp_path: Path) -> None:
    # A chunk HDF5 stored raw (filter_mask bit set) must skip the decoder.
    path = tmp_path / "mixed.h5"
    chunks = (1000,)
    shape = (4000,)
    # NOTE(review): the source text is truncated on the next line. Everything
    # between ``np.dtype("`` and `` None:`` is missing from this excerpt: the
    # dtype literal, the remainder of this test, and the ``def`` line of the
    # following test. It is kept verbatim; restore it from the upstream file.
    dtype = np.dtype(" None:
    # (250, 300) / (100, 100) chunks gives a clipped row, column, and corner;
    # edge chunks are zero-padded like HDF5 fill, then reassembled.
    rng = np.random.default_rng(0)
    array_shape = (250, 300)
    chunk_shape = (100, 100)
    arr = rng.standard_normal(array_shape, dtype=np.float32)

    out = np.empty(array_shape, dtype=arr.dtype)
    for i in range(0, array_shape[0], chunk_shape[0]):
        for j in range(0, array_shape[1], chunk_shape[1]):
            # Build a full-size chunk: valid data in the top-left, zeros elsewhere.
            chunk = np.zeros(chunk_shape, dtype=arr.dtype)
            valid = arr[i : i + chunk_shape[0], j : j + chunk_shape[1]]
            chunk[: valid.shape[0], : valid.shape[1]] = valid
            array_region, chunk_region = h5p._chunk_assignment(
                (i, j), chunk_shape, array_shape
            )
            out[array_region] = chunk[chunk_region]

    np.testing.assert_array_equal(out, arr)


def test_chunk_assignment_excludes_chunk_padding() -> None:
    # 1-D case: a 150-element array with 100-element chunks. The second chunk
    # carries a -999 sentinel in its overhang, which must never reach ``out``.
    array_shape = (150,)
    chunk_shape = (100,)

    interior_chunk = np.arange(100, 200, dtype=np.int32)
    # Edge chunk overhangs by 50: first 50 are real, the rest is fill padding.
    edge_chunk = np.full(chunk_shape, -999, dtype=np.int32)
    edge_chunk[:50] = np.arange(200, 250, dtype=np.int32)

    out = np.full(array_shape, -1, dtype=np.int32)

    array_region, chunk_region = h5p._chunk_assignment((0,), chunk_shape, array_shape)
    out[array_region] = interior_chunk[chunk_region]

    array_region, chunk_region = h5p._chunk_assignment((100,), chunk_shape, array_shape)
    out[array_region] = edge_chunk[chunk_region]

    np.testing.assert_array_equal(out[:100], np.arange(100, 200))
    np.testing.assert_array_equal(out[100:150], np.arange(200, 250))
    assert -999 not in out  # padding never leaked into the destination


# -------------------------------------------------------------------------------
# Codec dispatch — blosc thread configuration is lazy
# -------------------------------------------------------------------------------


def test_blosc_threading_configured_only_when_blosc_used(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    # set_nthreads(1) must fire the first time we decode blosc, not on import
    # and not for non-blosc reads.
    gzip_path = tmp_path / "gzip.h5ad"
    ad.AnnData(X=_csr()).write_h5ad(gzip_path)
    _rechunk_x_with_codec(gzip_path, CODEC_KWARGS["gzip"])
    blosc_path = tmp_path / "blosc.h5ad"
    ad.AnnData(X=_csr()).write_h5ad(blosc_path)
    _rechunk_x_with_codec(blosc_path, CODEC_KWARGS["blosc"])

    # Records every argument passed to the patched set_nthreads.
    calls: list[int] = []

    def fake_set_nthreads(n: int) -> None:
        calls.append(n)

    monkeypatch.setattr(numcodecs.blosc, "set_nthreads", fake_set_nthreads)
    # Reset the lazy gate so the spy can observe it.
    monkeypatch.setattr(h5p, "_blosc_threads_configured", False)

    with ad.settings.override(parallel_h5_read_min_mb=0):
        ad.read_h5ad(gzip_path)
        assert calls == []  # gzip read must not touch blosc thread state

        ad.read_h5ad(blosc_path)
        assert calls == [1]  # first blosc read configures it

        ad.read_h5ad(blosc_path)
        assert calls == [1]  # idempotent


# diff --git a/tests/test_io_pool.py b/tests/test_io_pool.py
# new file mode 100644
# index 000000000..20fe1a9d2
# --- /dev/null
# +++ b/tests/test_io_pool.py
# @@ -0,0 +1,182 @@
"""Tests for the IO-side thread-pool manager.

:class:`anndata._io._pool.PoolManager` is exercised end-to-end through
the HDF5 parallel-read path (see :mod:`test_io_h5_parallel`); these
tests pin its standalone semantics:

* lazy creation on first ``distribute_tasks_to_threads``,
* reuse when the requested worker count is unchanged,
* atomic resize-and-submit when the worker count changes,
* ``reset()`` returns the manager to its pre-use state,
* the configured thread-name prefix is applied to worker threads.
+""" + +from __future__ import annotations + +import threading +from concurrent.futures import wait +from typing import TYPE_CHECKING + +import pytest + +from anndata._io._pool import PoolManager + +if TYPE_CHECKING: + from collections.abc import Generator + + +@pytest.fixture +def mgr() -> Generator[PoolManager, None, None]: + """A fresh manager per test — prevents cross-test pool leakage.""" + pm = PoolManager(thread_name_prefix="anndata-test-pool") + yield pm + pm.reset() + + +def _square(x: int) -> int: + return x * x + + +def test_returns_one_future_per_task(mgr: PoolManager) -> None: + """One Future is returned per input task, in input order, each + resolving to ``fn(*task)``.""" + futs = mgr.distribute_tasks_to_threads(4, _square, [(i,) for i in range(8)]) + assert [f.result() for f in futs] == [i * i for i in range(8)] + + +def test_empty_tasks_returns_empty_list(mgr: PoolManager) -> None: + """An empty tasks iterable produces no futures and does not raise. + + Importantly, the pool *is* created on first call (even with no + work) so that the manager's invariant (`_pool is not None` after + any ``distribute_tasks_to_threads``) holds uniformly — exercised + by the reuse test below. 
+ """ + futs = mgr.distribute_tasks_to_threads(4, _square, []) + assert futs == [] + assert mgr._pool is not None + assert mgr._n_workers == 4 + + +def test_reuses_pool_when_workers_unchanged(mgr: PoolManager) -> None: + """Repeat calls with the same worker count must return the same + underlying executor — this is the persistent-pool win the manager + exists for.""" + mgr.distribute_tasks_to_threads(4, _square, [(1,)]) + pool_a = mgr._pool + mgr.distribute_tasks_to_threads(4, _square, [(2,)]) + pool_b = mgr._pool + assert pool_a is pool_b + + +def test_resizes_when_workers_change(mgr: PoolManager) -> None: + """Different worker count triggers a rebuild — the old pool is + discarded and a new one of the requested size takes its place.""" + mgr.distribute_tasks_to_threads(2, _square, [(1,)]) + pool_a = mgr._pool + assert mgr._n_workers == 2 + mgr.distribute_tasks_to_threads(5, _square, [(2,)]) + pool_b = mgr._pool + assert pool_b is not pool_a + assert mgr._n_workers == 5 + + +def test_resize_still_completes_in_flight_work(mgr: PoolManager) -> None: + """After a resize, futures submitted to the *old* pool must still + complete. ``shutdown(wait=False)`` only blocks new submissions — + work already enqueued runs to completion.""" + # Submit a batch that takes a moment, so we can race a resize past it. + sentinel = threading.Event() + + def slow(x: int) -> int: + sentinel.wait(timeout=5.0) + return x + 1 + + old_futs = mgr.distribute_tasks_to_threads(2, slow, [(10,), (20,)]) + # Resize while the old tasks are parked on the event. + mgr.distribute_tasks_to_threads(4, _square, [(3,)]) + sentinel.set() + # Both groups of tasks must complete. 
+ wait(old_futs) + assert sorted(f.result() for f in old_futs) == [11, 21] + + +def test_reset_tears_down_pool(mgr: PoolManager) -> None: + """``reset()`` waits for in-flight tasks and nulls the cached pool; + subsequent ``distribute_tasks_to_threads`` lazily creates a fresh + one.""" + mgr.distribute_tasks_to_threads(2, _square, [(1,)]) + assert mgr._pool is not None + mgr.reset() + assert mgr._pool is None + assert mgr._n_workers == 0 + + futs = mgr.distribute_tasks_to_threads(3, _square, [(4,)]) + assert mgr._pool is not None + assert mgr._n_workers == 3 + assert futs[0].result() == 16 + + +def test_reset_on_unused_manager_is_noop(mgr: PoolManager) -> None: + """``reset()`` before any submission must not raise — supports a + uniform test-teardown idiom regardless of whether the test + exercised the manager.""" + mgr.reset() + assert mgr._pool is None + + +def test_thread_name_prefix_is_applied() -> None: + """The constructor's prefix shows up on worker thread names so + profilers and crash dumps can attribute the threads back to anndata.""" + mgr = PoolManager(thread_name_prefix="anndata-test-prefix") + try: + name_box: list[str] = [] + mgr.distribute_tasks_to_threads( + 1, lambda: name_box.append(threading.current_thread().name), [()] + )[0].result() + assert name_box[0].startswith("anndata-test-prefix") + finally: + mgr.reset() + + +def test_concurrent_resize_and_submit_is_race_free() -> None: + """The central guarantee: two threads alternating + ``distribute_tasks_to_threads`` with *different* worker counts + must never trip a ``RuntimeError: cannot schedule new futures + after shutdown``. + + Pre-refactor, ``_get_pool`` released the lock before the caller + could submit; a concurrent resize could shut down the pool out + from under an in-flight submission. + ``distribute_tasks_to_threads`` holds the lock through submission, + so the race cannot occur. 
+ """ + mgr = PoolManager(thread_name_prefix="anndata-test-race") + try: + errors: list[BaseException] = [] + # Many alternating calls, each toggling between two sizes. + # If the race exists, on a contended scheduler at least one + # submission would land on a freshly-shutdown pool. + N = 200 + + def loop(num_workers: int) -> None: + for _ in range(N): + try: + futs = mgr.distribute_tasks_to_threads( + num_workers, _square, [(1,), (2,), (3,)] + ) + wait(futs) + for f in futs: + f.result() + except BaseException as e: # noqa: BLE001 + errors.append(e) + + t1 = threading.Thread(target=loop, args=(2,)) + t2 = threading.Thread(target=loop, args=(8,)) + t1.start() + t2.start() + t1.join() + t2.join() + assert not errors, f"unexpected errors during race test: {errors[:3]}" + finally: + mgr.reset()