diff --git a/pyproject.toml b/pyproject.toml index 418135298..04aaaab43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,6 +55,7 @@ dependencies = [ "matplotlib-scalebar>=0.8", "networkx>=2.6", "numba>=0.56.4", + "numba-progress>=1.1", "numpy>=1.23", "omnipath>=1.0.7", "pandas>=2.1", diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index 17e6a7a83..9dc9b14c2 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -19,7 +19,9 @@ import numpy as np import xarray as xr from spatialdata.models import Image2DModel, Labels2DModel -from tqdm.auto import tqdm + +if TYPE_CHECKING: + from numba_progress import ProgressBar __all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"] @@ -262,9 +264,9 @@ def thread_map( n_jobs Number of worker threads. ``1`` runs sequentially (no pool overhead). show_progress_bar - Whether to display a ``tqdm`` progress bar. + Whether to display a ``numba_progress`` progress bar. unit - Label shown next to the ``tqdm`` counter. + Label shown next to the progress counter. Returns ------- @@ -273,17 +275,31 @@ def thread_map( """ from concurrent.futures import ThreadPoolExecutor - if n_jobs == 1: - it: Iterable[Any] = map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) - - with ThreadPoolExecutor(max_workers=n_jobs) as pool: - it = pool.map(fn, items) - if show_progress_bar and tqdm is not None: - it = tqdm(it, total=len(items), unit=unit) - return list(it) + items = list(items) + + def _run(pbar: ProgressBar | None) -> list[Any]: + def _consume(results_it: Iterable[Any]) -> list[Any]: + results = [] + # ``map``/``pool.map`` yield in submission order, so results stay aligned with *items*. + for res in results_it: + results.append(res) + if pbar is not None: + pbar.update(1) + return results + + if n_jobs == 1: + return _consume(map(fn, items)) + + with ThreadPoolExecutor(max_workers=n_jobs) as pool: + return _consume(pool.map(fn, items)) + + if show_progress_bar: + from numba_progress import ProgressBar + + with ProgressBar(total=len(items), unit=unit) as pbar: + return _run(pbar) + + return _run(None) def _get_n_cores(n_cores: int | None) -> int: