Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ dependencies = [
"matplotlib-scalebar>=0.8",
"networkx>=2.6",
"numba>=0.56.4",
"numba-progress>=1.1",
"numpy>=1.23",
"omnipath>=1.0.7",
"pandas>=2.1",
Expand Down
44 changes: 30 additions & 14 deletions src/squidpy/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
import numpy as np
import xarray as xr
from spatialdata.models import Image2DModel, Labels2DModel
from tqdm.auto import tqdm

if TYPE_CHECKING:
from numba_progress import ProgressBar

__all__ = ["singledispatchmethod", "Signal", "SigQueue", "NDArray", "NDArrayA"]

Expand Down Expand Up @@ -262,9 +264,9 @@ def thread_map(
n_jobs
Number of worker threads. ``1`` runs sequentially (no pool overhead).
show_progress_bar
Whether to display a ``tqdm`` progress bar.
Whether to display a ``numba_progress`` progress bar.
unit
Label shown next to the ``tqdm`` counter.
Label shown next to the progress counter.

Returns
-------
Expand All @@ -273,17 +275,31 @@ def thread_map(
"""
from concurrent.futures import ThreadPoolExecutor

if n_jobs == 1:
it: Iterable[Any] = map(fn, items)
if show_progress_bar and tqdm is not None:
it = tqdm(it, total=len(items), unit=unit)
return list(it)

with ThreadPoolExecutor(max_workers=n_jobs) as pool:
it = pool.map(fn, items)
if show_progress_bar and tqdm is not None:
it = tqdm(it, total=len(items), unit=unit)
return list(it)
items = list(items)

def _run(pbar: ProgressBar | None) -> list[Any]:
def _consume(results_it: Iterable[Any]) -> list[Any]:
results = []
# ``map``/``pool.map`` yield in submission order, so results stay aligned with *items*.
for res in results_it:
results.append(res)
if pbar is not None:
pbar.update(1)
return results

if n_jobs == 1:
return _consume(map(fn, items))

with ThreadPoolExecutor(max_workers=n_jobs) as pool:
return _consume(pool.map(fn, items))

if show_progress_bar:
from numba_progress import ProgressBar

with ProgressBar(total=len(items), unit=unit) as pbar:
return _run(pbar)

return _run(None)


def _get_n_cores(n_cores: int | None) -> int:
Expand Down
Loading