diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index 8e2758554..549adb019 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "uvicorn[standard]>=0.30.0", "python-multipart>=0.0.9", # HTTP clients + "aiohttp>=3.9.0", "httpx>=0.27.0", "requests>=2.32.5", "urllib3==2.6.3", diff --git a/nemo_retriever/src/nemo_retriever/audio/asr_actor.py b/nemo_retriever/src/nemo_retriever/audio/asr_actor.py index 488dd8db4..4aea7308c 100644 --- a/nemo_retriever/src/nemo_retriever/audio/asr_actor.py +++ b/nemo_retriever/src/nemo_retriever/audio/asr_actor.py @@ -401,4 +401,4 @@ def apply_asr_to_df( """ params = ASRParams(**(asr_params or {})) actor = ASRActor(params=params) - return actor(batch_df) + return actor.run(batch_df) diff --git a/nemo_retriever/src/nemo_retriever/chart/cpu_actor.py b/nemo_retriever/src/nemo_retriever/chart/cpu_actor.py index 647feff13..75820548a 100644 --- a/nemo_retriever/src/nemo_retriever/chart/cpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/chart/cpu_actor.py @@ -15,7 +15,7 @@ from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements +from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements logger = logging.getLogger(__name__) @@ -143,9 +143,23 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await agraphic_elements_ocr_page_elements( + data, + graphic_elements_model=self._graphic_elements_model, + ocr_model=self._ocr_model, + graphic_elements_invoke_url=self._graphic_elements_invoke_url, + ocr_invoke_url=self._ocr_invoke_url, + 
api_key=self._api_key, + request_timeout_s=self._request_timeout_s, + remote_retry=self._remote_retry, + inference_batch_size=self._inference_batch_size, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/chart/gpu_actor.py b/nemo_retriever/src/nemo_retriever/chart/gpu_actor.py index d856f7f46..87fe92af9 100644 --- a/nemo_retriever/src/nemo_retriever/chart/gpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/chart/gpu_actor.py @@ -12,7 +12,7 @@ from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements +from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements class GraphicElementsActor(AbstractOperator, GPUOperator): @@ -88,9 +88,23 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await agraphic_elements_ocr_page_elements( + data, + graphic_elements_model=self._graphic_elements_model, + ocr_model=self._ocr_model, + graphic_elements_invoke_url=self._graphic_elements_invoke_url, + ocr_invoke_url=self._ocr_invoke_url, + api_key=self._api_key, + request_timeout_s=self._request_timeout_s, + remote_retry=self._remote_retry, + inference_batch_size=self._inference_batch_size, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, 
async def agraphic_elements_ocr_page_elements(
    batch_df: Any,
    *,
    graphic_elements_model: Any = None,
    ocr_model: Any = None,
    graphic_elements_invoke_url: str = "",
    ocr_invoke_url: str = "",
    api_key: str = "",
    request_timeout_s: float = 120.0,
    remote_retry: RemoteRetryParams | None = None,
    **kwargs: Any,
) -> Any:
    """Async version of :func:`graphic_elements_ocr_page_elements`.

    For each row of ``batch_df`` the ``"chart"`` detections listed in
    ``page_elements_v3["detections"]`` are cropped out of
    ``page_image["image_b64"]``.  Every crop is run through graphic-elements
    detection and OCR (each either remote or local), and the two outputs are
    joined into a text string.  When both stages are remote, the two HTTP
    fan-outs for a row run concurrently.

    Args:
        batch_df: Input batch; must be a ``pandas.DataFrame``.
        graphic_elements_model: Local graphic-elements model.  Required when
            no graphic-elements invoke URL is configured.
        ocr_model: Local OCR model.  Required when no OCR invoke URL is
            configured.
        graphic_elements_invoke_url: Remote graphic-elements endpoint.  Falls
            back to ``kwargs["graphic_elements_invoke_url"]`` when empty.
        ocr_invoke_url: Remote OCR endpoint.  Falls back to
            ``kwargs["ocr_invoke_url"]`` when empty.
        api_key: Bearer token forwarded to remote endpoints (empty -> none).
        request_timeout_s: Per-request timeout for remote calls, in seconds.
        remote_retry: Retry/concurrency settings.  When omitted they are
            built from ``remote_max_pool_workers``, ``remote_max_retries``
            and ``remote_max_429_retries`` in ``kwargs``.
        **kwargs: Also read for ``inference_batch_size`` (default ``8``).

    Returns:
        A copy of ``batch_df`` with two added columns: ``chart`` (list of
        ``{"bbox_xyxy_norm", "text"}`` dicts per row) and
        ``graphic_elements_ocr_v1`` (``{"timing", "error"}`` per row).

    Raises:
        NotImplementedError: If ``batch_df`` is not a ``pandas.DataFrame``.
        ValueError: If a stage has neither a remote URL nor a local model.
        asyncio.CancelledError: Propagated unchanged when the task is
            cancelled; it is never recorded as a row error.
    """
    import asyncio

    from nemo_retriever.nim.nim import ainvoke_image_inference_batches
    from nemo_retriever.ocr.ocr import (
        _blocks_to_text,
        _crop_all_from_page,
        _extract_remote_ocr_item,
        _np_rgb_to_b64_png,
        _parse_ocr_result,
    )
    from nemo_retriever.utils.table_and_chart import join_graphic_elements_and_ocr_output

    retry = remote_retry or RemoteRetryParams(
        remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)),
        remote_max_retries=int(kwargs.get("remote_max_retries", 10)),
        remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)),
    )

    if not isinstance(batch_df, pd.DataFrame):
        raise NotImplementedError("agraphic_elements_ocr_page_elements currently only supports pandas.DataFrame input.")

    ge_url = (graphic_elements_invoke_url or kwargs.get("graphic_elements_invoke_url") or "").strip()
    ocr_url = (ocr_invoke_url or kwargs.get("ocr_invoke_url") or "").strip()
    use_remote_ge = bool(ge_url)
    use_remote_ocr = bool(ocr_url)

    if not use_remote_ge and graphic_elements_model is None:
        raise ValueError("A local `graphic_elements_model` is required when `graphic_elements_invoke_url` is not set.")
    if not use_remote_ocr and ocr_model is None:
        raise ValueError("A local `ocr_model` is required when `ocr_invoke_url` is not set.")

    label_names = _labels_from_model(graphic_elements_model) if graphic_elements_model is not None else []
    inference_batch_size = int(kwargs.get("inference_batch_size", 8))

    def _invoke_remote(url: str, images_b64: list) -> Any:
        # Returns the un-awaited coroutine so the caller can either await it
        # directly or hand two of them to ``asyncio.gather``.
        return ainvoke_image_inference_batches(
            invoke_url=url,
            image_b64_list=images_b64,
            api_key=api_key or None,
            timeout_s=float(request_timeout_s),
            max_batch_size=inference_batch_size,
            max_concurrency=int(retry.remote_max_pool_workers),
            max_retries=int(retry.remote_max_retries),
            max_429_retries=int(retry.remote_max_429_retries),
        )

    def _confident(detections: Any) -> list:
        # Drop detections below the graphic-elements score threshold.
        return [d for d in detections if (d.get("score") or 0.0) >= YOLOX_GRAPHIC_MIN_SCORE]

    def _ge_from_responses(responses: Any, expected: int) -> list:
        # Validate the response count, then convert each response to detections.
        if len(responses) != expected:
            raise RuntimeError(f"Expected {expected} GE responses, got {len(responses)}")
        return [_confident(_remote_response_to_ge_detections(resp)) for resp in responses]

    def _ocr_from_responses(responses: Any, expected: int) -> list:
        # Validate the response count, then extract the OCR payload of each item.
        if len(responses) != expected:
            raise RuntimeError(f"Expected {expected} OCR responses, got {len(responses)}")
        return [_extract_remote_ocr_item(resp) for resp in responses]

    def _run_local_ge(crops: Any) -> list:
        # Blocking local inference; executed through ``asyncio.to_thread``.
        results = []
        for _, _, crop_array in crops:
            chw = torch.from_numpy(crop_array).permute(2, 0, 1).contiguous().to(dtype=torch.float32)
            h, w = crop_array.shape[:2]
            x = chw.unsqueeze(0)
            try:
                pre = graphic_elements_model.preprocess(x)
            except Exception:
                # Best effort: fall back to the raw tensor when preprocess fails.
                pre = x
            if isinstance(pre, torch.Tensor) and pre.ndim == 3:
                pre = pre.unsqueeze(0)
            pred = graphic_elements_model.invoke(pre, (h, w))
            results.append(_confident(_prediction_to_detections(pred, label_names=label_names)))
        return results

    def _run_local_ocr(crops: Any) -> list:
        # Blocking local inference; executed through ``asyncio.to_thread``.
        return [ocr_model.invoke(crop_array, merge_level="word") for _, _, crop_array in crops]

    async def _collect_row_charts(row: Any, chart_items: list) -> None:
        # Appends into ``chart_items`` in place so that items produced before
        # a mid-row failure are still reported for that row.
        pe = getattr(row, "page_elements_v3", None)
        dets = (pe.get("detections") or []) if isinstance(pe, dict) else []
        if not isinstance(dets, list):
            dets = []

        page_image = getattr(row, "page_image", None) or {}
        page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None
        if not isinstance(page_image_b64, str) or not page_image_b64:
            return

        crops = _crop_all_from_page(page_image_b64, dets, {"chart"})
        if not crops:
            return

        # Base64 encoding is only needed when at least one stage is remote.
        crop_b64s = (
            [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops]
            if (use_remote_ge or use_remote_ocr)
            else []
        )
        n_crops = len(crops)

        if use_remote_ge and use_remote_ocr:
            # Both stages are pure I/O: overlap them.
            ge_items, ocr_items = await asyncio.gather(
                _invoke_remote(ge_url, crop_b64s),
                _invoke_remote(ocr_url, crop_b64s),
            )
            ge_results = _ge_from_responses(ge_items, n_crops)
            ocr_results = _ocr_from_responses(ocr_items, n_crops)
        else:
            if use_remote_ge:
                ge_results = _ge_from_responses(await _invoke_remote(ge_url, crop_b64s), n_crops)
            else:
                ge_results = await asyncio.to_thread(_run_local_ge, crops)

            if use_remote_ocr:
                ocr_results = _ocr_from_responses(await _invoke_remote(ocr_url, crop_b64s), n_crops)
            else:
                ocr_results = await asyncio.to_thread(_run_local_ocr, crops)

        for crop_i, (_, bbox, crop_array) in enumerate(crops):
            crop_hw = (int(crop_array.shape[0]), int(crop_array.shape[1]))
            ocr_preds = ocr_results[crop_i]

            text = join_graphic_elements_and_ocr_output(ge_results[crop_i], ocr_preds, crop_hw)
            if not text:
                # Fall back to plain OCR text when the join yields nothing.
                text = _blocks_to_text(_parse_ocr_result(ocr_preds))

            chart_items.append({"bbox_xyxy_norm": bbox, "text": text})

    all_chart: List[List[Dict[str, Any]]] = []
    all_meta: List[Dict[str, Any]] = []

    t0_total = time.perf_counter()

    for row in batch_df.itertuples(index=False):
        chart_items: List[Dict[str, Any]] = []
        row_error: Any = None

        try:
            await _collect_row_charts(row, chart_items)
        except asyncio.CancelledError:
            # Cancellation is control flow, not a row failure: swallowing it
            # would keep a cancelled task running through the remaining rows.
            raise
        except BaseException as e:
            logger.warning("graphic-elements+OCR failed: %s: %s", type(e).__name__, e, exc_info=True)
            row_error = {
                "stage": "graphic_elements_ocr_page_elements",
                "type": e.__class__.__name__,
                "message": str(e),
                "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)),
            }

        all_chart.append(chart_items)
        all_meta.append({"timing": None, "error": row_error})

    # Timing is measured for the whole batch and stamped on every row.
    elapsed = time.perf_counter() - t0_total
    for meta in all_meta:
        meta["timing"] = {"seconds": float(elapsed)}

    out = batch_df.copy()
    out["chart"] = all_chart
    out["graphic_elements_ocr_v1"] = all_meta
    return out
+_ensure_event_loop() + + class AbstractOperator(ABC): """Base class for all pipeline operators.""" @@ -35,9 +55,31 @@ def run(self, data: Any, **kwargs: Any) -> Any: data = self.postprocess(data, **kwargs) return data - def __call__(self, data: Any, **kwargs: Any) -> Any: - """Make operators directly usable as Ray ``map_batches`` callables.""" - return self.run(data, **kwargs) + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + """Async version of :meth:`process`. + + The default calls ``process()`` synchronously on the event-loop + thread. This is intentional: most operators use C extensions + (pypdfium2, OpenCV, torch, …) that are **not** thread-safe, so + ``asyncio.to_thread`` would allow Ray Data to run multiple + batches in parallel threads inside the same actor, causing + memory corruption. + + I/O-bound subclasses that call remote endpoints should override + this with a proper ``await``-based implementation. + """ + return self.process(data, **kwargs) + + async def arun(self, data: Any, **kwargs: Any) -> Any: + """Async version of :meth:`run`.""" + data = self.preprocess(data, **kwargs) + data = await self.aprocess(data, **kwargs) + data = self.postprocess(data, **kwargs) + return data + + async def __call__(self, data: Any, **kwargs: Any) -> Any: + """Make operators directly usable as Ray ``map_batches`` async callables.""" + return await self.arun(data, **kwargs) def get_constructor_kwargs(self) -> dict[str, Any]: """Best-effort constructor kwargs for executor-side reconstruction.""" diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py index c2dbc83f3..ffec8dd24 100644 --- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py +++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py @@ -22,7 +22,7 @@ explode_content_to_rows, ) from nemo_retriever.graph.multi_type_extract_operator import MultiTypeExtractOperator -from 
nemo_retriever.text_embed.operators import _BatchEmbedActor +from nemo_retriever.text_embed.operators import BatchEmbedActor from nemo_retriever.ocr.ocr import OCRActor from nemo_retriever.parse.nemotron_parse import NemotronParseActor from nemo_retriever.page_elements.page_elements import PageElementDetectionActor @@ -119,9 +119,9 @@ def _force_cpu_only(node_name: str) -> None: embed_invoke_url = _positive(getattr(embed_params, "embed_invoke_url", None)) explicit_bs = getattr(embed_tuning, "embed_batch_size", None) if embed_tuning is not None else None embed_bs = _positive(explicit_bs) or (plan.embed_batch_size if plan else None) - _set(_BatchEmbedActor.__name__, "batch_size", embed_bs) + _set(BatchEmbedActor.__name__, "batch_size", embed_bs) if embed_bs: - overrides.setdefault(_BatchEmbedActor.__name__, {})["target_num_rows_per_block"] = embed_bs + overrides.setdefault(BatchEmbedActor.__name__, {})["target_num_rows_per_block"] = embed_bs embed_concurrency = ( _resolve( getattr(embed_tuning, "embed_workers", None) if embed_tuning is not None else None, @@ -129,19 +129,19 @@ def _force_cpu_only(node_name: str) -> None: ) or 0 ) - _set(_BatchEmbedActor.__name__, "concurrency", embed_concurrency or None) + _set(BatchEmbedActor.__name__, "concurrency", embed_concurrency or None) embed_cpus = ( _resolve( getattr(embed_tuning, "embed_cpus_per_actor", None) if embed_tuning is not None else None, ) or 1.0 ) - _set(_BatchEmbedActor.__name__, "num_cpus", embed_cpus if embed_cpus != 1.0 else None) + _set(BatchEmbedActor.__name__, "num_cpus", embed_cpus if embed_cpus != 1.0 else None) if effective_allow_no_gpu: - _force_cpu_only(_BatchEmbedActor.__name__) + _force_cpu_only(BatchEmbedActor.__name__) elif not embed_invoke_url: _set_gpu( - _BatchEmbedActor.__name__, + BatchEmbedActor.__name__, getattr(embed_tuning, "gpu_embed", None) if embed_tuning is not None else None, plan.embed_gpus_per_actor if plan else None, ) @@ -474,7 +474,7 @@ def _append_ordered_transform_stages( ), 
name="ExplodeContentToRows", ) - graph = graph >> _BatchEmbedActor(params=embed_params) + graph = graph >> BatchEmbedActor(params=embed_params) return graph diff --git a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py index 3fb968263..c28e6e885 100644 --- a/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py +++ b/nemo_retriever/src/nemo_retriever/graph/operator_archetype.py @@ -71,8 +71,15 @@ def postprocess(self, data: Any, **kwargs: Any) -> Any: def run(self, data: Any, **kwargs: Any) -> Any: return self._resolve_delegate().run(data, **kwargs) - def __call__(self, data: Any, **kwargs: Any) -> Any: - return self._resolve_delegate()(data, **kwargs) + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await self._resolve_delegate().aprocess(data, **kwargs) + + async def arun(self, data: Any, **kwargs: Any) -> Any: + return await self._resolve_delegate().arun(data, **kwargs) + + async def __call__(self, data: Any, **kwargs: Any) -> Any: + delegate = self._resolve_delegate() + return await delegate(data, **kwargs) def _resolve_delegate(self, resources: ClusterResources | Resources | None = None) -> AbstractOperator: if not hasattr(self, "_resolved_delegate"): diff --git a/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py b/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py index a97cf22f1..0a45f76e9 100644 --- a/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/tabular_fetch_embeddings_operator.py @@ -22,7 +22,7 @@ class TabularFetchEmbeddingsOp(AbstractOperator, CPUOperator): ``text``, ``_embed_modality``, ``path``, ``page_number``, ``metadata``. 
The output schema matches the format produced by the unstructured pipeline, - so the standard :class:`~nemo_retriever.text_embed.operators._BatchEmbedActor` + so the standard :class:`~nemo_retriever.text_embed.operators.BatchEmbedActor` can be chained directly after this operator. """ diff --git a/nemo_retriever/src/nemo_retriever/html/ray_data.py b/nemo_retriever/src/nemo_retriever/html/ray_data.py index 28723efe9..8bd7eaf4e 100644 --- a/nemo_retriever/src/nemo_retriever/html/ray_data.py +++ b/nemo_retriever/src/nemo_retriever/html/ray_data.py @@ -71,8 +71,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: - return self.run(batch_df) + async def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: + return await self.arun(batch_df) class HtmlSplitActor(ArchetypeOperator): diff --git a/nemo_retriever/src/nemo_retriever/image/ray_data.py b/nemo_retriever/src/nemo_retriever/image/ray_data.py index 3b2669f86..c5726c039 100644 --- a/nemo_retriever/src/nemo_retriever/image/ray_data.py +++ b/nemo_retriever/src/nemo_retriever/image/ray_data.py @@ -83,8 +83,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: - return self.run(batch_df) + async def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: + return await self.arun(batch_df) class ImageLoadActor(ArchetypeOperator): diff --git a/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py b/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py index 604bc0283..5202f1d9d 100644 --- a/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py +++ b/nemo_retriever/src/nemo_retriever/infographic/infographic_detection.py @@ -796,9 +796,9 @@ def process(self, batch_df: Any, **override_kwargs: Any) -> Any: def 
async def _apost_with_retries(
    *,
    session: aiohttp.ClientSession,
    invoke_url: str,
    payload: Dict[str, Any],
    headers: Dict[str, str],
    timeout_s: float,
    max_retries: int,
    max_429_retries: int,
) -> Any:
    """POST ``payload`` as JSON to ``invoke_url`` with retry and exponential backoff.

    Retry policy:

    * HTTP 429 is retried on its own budget (``max_429_retries``) and does
      not consume a regular attempt.
    * HTTP 5xx, timeouts and other ``aiohttp.ClientError`` failures are
      retried until ``max_retries`` attempts have been used.
    * Any other HTTP 4xx is raised immediately with the response body
      included in the error message.

    The backoff sleep happens *after* the response context has been left, so
    the pooled connection is released first and a waiting request does not
    occupy one of the connector's limited slots.

    Args:
        session: Open aiohttp session used to send the request.
        invoke_url: Endpoint URL.
        payload: JSON-serialisable request body.
        headers: HTTP headers sent with every attempt.
        timeout_s: Total per-attempt timeout in seconds.
        max_retries: Maximum number of regular attempts.
        max_429_retries: Maximum number of 429 responses tolerated.

    Returns:
        The decoded JSON body of the first successful response.

    Raises:
        TimeoutError: If the last allowed attempt timed out.
        aiohttp.ClientError: For a non-retryable 4xx response, an exhausted
            429 budget, or a failure on the last allowed attempt.
        RuntimeError: If the loop ends without a response (for example when
            ``max_retries`` is not positive).
    """
    base_delay = 2.0
    attempt = 0
    retries_429 = 0
    max_attempts = int(max_retries)
    max_429 = int(max_429_retries)
    timeout = aiohttp.ClientTimeout(total=float(timeout_s))

    while attempt < max_attempts:
        try:
            async with session.post(invoke_url, json=payload, headers=headers, timeout=timeout) as response:
                status_code = response.status

                if status_code == 429:
                    retries_429 += 1
                    if retries_429 >= max_429:
                        response.raise_for_status()
                    backoff_time = base_delay * (2**retries_429)
                    logger.warning(
                        "NIM endpoint %s returned 429 (rate limited). Retry %d/%d after %.1fs backoff.",
                        invoke_url,
                        retries_429,
                        max_429_retries,
                        backoff_time,
                    )
                elif 500 <= status_code < 600:
                    if attempt == max_attempts - 1:
                        response.raise_for_status()
                    backoff_time = base_delay * (2**attempt)
                    logger.warning(
                        "NIM endpoint %s returned %d. Retry %d/%d after %.1fs backoff.",
                        invoke_url,
                        status_code,
                        attempt + 1,
                        max_retries,
                        backoff_time,
                    )
                    attempt += 1
                elif 400 <= status_code < 500:
                    # Client errors are not retryable; surface the body to the caller.
                    body = await response.text()
                    raise aiohttp.ClientResponseError(
                        response.request_info,
                        response.history,
                        status=status_code,
                        message=f"HTTP {status_code} from {invoke_url}: {body}",
                    )
                else:
                    response.raise_for_status()
                    return await response.json()

        except asyncio.TimeoutError as exc:
            if attempt == max_attempts - 1:
                raise TimeoutError(f"Request timed out after {attempt + 1} attempts.") from exc
            backoff_time = base_delay * (2**attempt)
            logger.warning(
                "NIM endpoint %s timed out (%.1fs). Retry %d/%d after %.1fs backoff.",
                invoke_url,
                timeout_s,
                attempt + 1,
                max_retries,
                backoff_time,
            )
            attempt += 1
        except aiohttp.ClientError as exc:
            resp_status = getattr(exc, "status", None)
            # 4xx (including an exhausted 429 budget) is final.
            if resp_status is not None and 400 <= resp_status < 500:
                raise
            if attempt == max_attempts - 1:
                raise
            backoff_time = base_delay * (2**attempt)
            logger.warning(
                "NIM endpoint %s request failed: %s. Retry %d/%d after %.1fs backoff.",
                invoke_url,
                exc,
                attempt + 1,
                max_retries,
                backoff_time,
            )
            attempt += 1

        # Every path that reaches this point has scheduled a retry.  The
        # response context is already closed, so the connection is back in
        # the pool while this request waits.
        await asyncio.sleep(backoff_time)

    raise RuntimeError(f"Failed to get a successful response after {max_retries} retries.")
async def ainvoke_page_elements_batches(
    *,
    invoke_url: str,
    image_b64_list: Sequence[str],
    api_key: Optional[str] = None,
    timeout_s: float = 60.0,
    max_batch_size: int = 8,
    max_concurrency: int = 8,
    max_retries: int = 5,
    max_429_retries: int = 3,
) -> List[Any]:
    """Forward page-elements requests to :func:`ainvoke_image_inference_batches`.

    Kept as a backward-compatible async entry point for page-elements
    callers; every argument is passed through unchanged and the result of
    the underlying call is returned as is.
    """
    forwarded = {
        "invoke_url": invoke_url,
        "image_b64_list": image_b64_list,
        "api_key": api_key,
        "timeout_s": timeout_s,
        "max_batch_size": max_batch_size,
        "max_concurrency": max_concurrency,
        "max_retries": max_retries,
        "max_429_retries": max_429_retries,
    }
    per_image_results = await ainvoke_image_inference_batches(**forwarded)
    return per_image_results
NIMClient from nemo_retriever.params import RemoteRetryParams from nemo_retriever.ocr.shared import _error_payload -from nemo_retriever.ocr.shared import ocr_page_elements +from nemo_retriever.ocr.shared import aocr_page_elements, ocr_page_elements class OCRCPUActor(AbstractOperator, CPUOperator): @@ -65,9 +65,18 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await aocr_page_elements( + data, + model=self._model, + remote_retry=self._remote_retry, + **self.ocr_kwargs, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/ocr/cpu_parse.py b/nemo_retriever/src/nemo_retriever/ocr/cpu_parse.py index bface8158..d87b8f7e1 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/cpu_parse.py +++ b/nemo_retriever/src/nemo_retriever/ocr/cpu_parse.py @@ -10,7 +10,7 @@ from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.ocr.shared import nemotron_parse_page_elements +from nemo_retriever.ocr.shared import anemotron_parse_page_elements, nemotron_parse_page_elements class NemotronParseCPUActor(AbstractOperator, CPUOperator): @@ -75,3 +75,19 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data + + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await anemotron_parse_page_elements( + data, + model=self._model, + invoke_url=self._invoke_url, + api_key=self._api_key, + 
request_timeout_s=self._request_timeout_s, + task_prompt=self._task_prompt, + extract_text=self._extract_text, + extract_tables=self._extract_tables, + extract_charts=self._extract_charts, + extract_infographics=self._extract_infographics, + remote_retry=self._remote_retry, + **kwargs, + ) diff --git a/nemo_retriever/src/nemo_retriever/ocr/gpu_ocr.py b/nemo_retriever/src/nemo_retriever/ocr/gpu_ocr.py index 94e57a871..3cca7b074 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/gpu_ocr.py +++ b/nemo_retriever/src/nemo_retriever/ocr/gpu_ocr.py @@ -12,7 +12,7 @@ from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.ocr.shared import Image, _error_payload, ocr_page_elements +from nemo_retriever.ocr.shared import Image, _error_payload, aocr_page_elements, ocr_page_elements class OCRActor(AbstractOperator, GPUOperator): @@ -71,9 +71,18 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await aocr_page_elements( + data, + model=self._model, + remote_retry=self._remote_retry, + **self.ocr_kwargs, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/ocr/gpu_parse.py b/nemo_retriever/src/nemo_retriever/ocr/gpu_parse.py index 61acf3670..c5f751582 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/gpu_parse.py +++ b/nemo_retriever/src/nemo_retriever/ocr/gpu_parse.py @@ -12,7 +12,7 @@ from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.nim.nim import 
NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.ocr.shared import _error_payload, nemotron_parse_page_elements +from nemo_retriever.ocr.shared import _error_payload, anemotron_parse_page_elements, nemotron_parse_page_elements class NemotronParseActor(AbstractOperator, GPUOperator): @@ -82,9 +82,25 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await anemotron_parse_page_elements( + data, + model=self._model, + invoke_url=self._invoke_url, + api_key=self._api_key, + request_timeout_s=self._request_timeout_s, + task_prompt=self._task_prompt, + extract_text=self._extract_text, + extract_tables=self._extract_tables, + extract_charts=self._extract_charts, + extract_infographics=self._extract_infographics, + remote_retry=self._remote_retry, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/ocr/shared.py b/nemo_retriever/src/nemo_retriever/ocr/shared.py index 72fa20b5e..cbd0b81b3 100644 --- a/nemo_retriever/src/nemo_retriever/ocr/shared.py +++ b/nemo_retriever/src/nemo_retriever/ocr/shared.py @@ -22,7 +22,7 @@ import numpy as np import pandas as pd from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.nim.nim import NIMClient, invoke_image_inference_batches +from nemo_retriever.nim.nim import NIMClient, ainvoke_image_inference_batches, invoke_image_inference_batches from nemo_retriever.utils.table_and_chart import ( join_graphic_elements_and_ocr_output, join_table_structure_and_ocr_output, @@ -1055,3 +1055,389 @@ def 
nemotron_parse_page_elements( out["infographic_parse"] = all_infographic out["nemotron_parse_v1_2"] = all_meta return out + + +async def aocr_page_elements( + batch_df: Any, + *, + model: Any = None, + invoke_url: Optional[str] = None, + api_key: Optional[str] = None, + request_timeout_s: float = 120.0, + extract_text: bool = False, + extract_tables: bool = False, + extract_charts: bool = False, + extract_infographics: bool = False, + use_graphic_elements: bool = False, + inference_batch_size: int = 8, + remote_retry: RemoteRetryParams | None = None, + **kwargs: Any, +) -> Any: + """Async version of :func:`ocr_page_elements`. + + Remote NIM calls use ``ainvoke_image_inference_batches``; local GPU + inference is delegated to ``asyncio.to_thread`` wrapping the sync version. + """ + import asyncio + + invoke_url_resolved = (invoke_url or kwargs.get("ocr_invoke_url") or "").strip() + use_remote = bool(invoke_url_resolved) + + if not use_remote: + return await asyncio.to_thread( + ocr_page_elements, + batch_df, + model=model, + invoke_url=invoke_url, + api_key=api_key, + request_timeout_s=request_timeout_s, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, + use_graphic_elements=use_graphic_elements, + inference_batch_size=inference_batch_size, + remote_retry=remote_retry, + **kwargs, + ) + + retry = remote_retry or RemoteRetryParams( + remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), + ) + + if not isinstance(batch_df, pd.DataFrame): + raise NotImplementedError("aocr_page_elements currently only supports pandas.DataFrame input.") + + wanted_labels: set[str] = set() + if extract_tables: + wanted_labels.add("table") + if extract_charts: + wanted_labels.add("chart") + if extract_infographics: + wanted_labels.add("infographic") + 
+ all_table: List[List[Dict[str, Any]]] = [] + all_chart: List[List[Dict[str, Any]]] = [] + all_infographic: List[List[Dict[str, Any]]] = [] + all_text: List[str] = [] + all_ocr_meta: List[Dict[str, Any]] = [] + + t0_total = time.perf_counter() + + for row in batch_df.itertuples(index=False): + table_items: List[Dict[str, Any]] = [] + chart_items: List[Dict[str, Any]] = [] + infographic_items: List[Dict[str, Any]] = [] + row_ocr_text_blocks: List[Dict[str, Any]] = [] + row_error: Any = None + + try: + pe = getattr(row, "page_elements_v3", None) + dets: List[Dict[str, Any]] = [] + if isinstance(pe, dict): + dets = pe.get("detections") or [] + if not isinstance(dets, list): + dets = [] + + page_image = getattr(row, "page_image", None) or {} + page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None + + if not isinstance(page_image_b64, str) or not page_image_b64: + all_table.append(table_items) + all_chart.append(chart_items) + all_infographic.append(infographic_items) + all_text.append(None) + all_ocr_meta.append({"timing": None, "error": None}) + continue + + row_wanted = wanted_labels + if extract_text: + meta = getattr(row, "metadata", None) or {} + needs_ocr = meta.get("needs_ocr_for_text", False) if isinstance(meta, dict) else False + if needs_ocr: + row_wanted = wanted_labels | _TEXT_LABELS + + crops = _crop_all_from_page(page_image_b64, dets, row_wanted, as_b64=True) + crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops] + crop_meta: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops] + + if crop_b64s: + response_items = await ainvoke_image_inference_batches( + invoke_url=invoke_url_resolved, + image_b64_list=crop_b64s, + api_key=api_key, + timeout_s=float(request_timeout_s), + max_batch_size=int(kwargs.get("inference_batch_size", 8)), + max_concurrency=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + if 
len(response_items) != len(crop_meta): + raise RuntimeError(f"Expected {len(crop_meta)} OCR responses, got {len(response_items)}") + + for i, (label_name, bbox) in enumerate(crop_meta): + preds = _extract_remote_ocr_item(response_items[i]) + + if label_name == "chart" and use_graphic_elements: + ge_dets = _find_ge_detections_for_bbox(row, bbox) + if ge_dets: + crop_hw = (0, 0) + try: + _raw = base64.b64decode(crop_b64s[i]) + with Image.open(io.BytesIO(_raw)) as _cim: + _cw, _ch = _cim.size + crop_hw = (_ch, _cw) + except Exception: + pass + text = join_graphic_elements_and_ocr_output(ge_dets, preds, crop_hw) + if text: + chart_items.append({"bbox_xyxy_norm": bbox, "text": text}) + continue + + blocks = _parse_ocr_result(preds) + if label_name == "table": + crop_hw_table: Tuple[int, int] = (0, 0) + try: + _raw = base64.b64decode(crop_b64s[i]) + with Image.open(io.BytesIO(_raw)) as _cim: + _cw, _ch = _cim.size + crop_hw_table = (_ch, _cw) + except Exception: + pass + text = _blocks_to_pseudo_markdown(blocks, crop_hw=crop_hw_table) or _blocks_to_text(blocks) + else: + text = _blocks_to_text(blocks) + entry = {"bbox_xyxy_norm": bbox, "text": text} + if label_name == "table": + table_items.append(entry) + elif label_name == "chart": + chart_items.append(entry) + elif label_name == "infographic": + infographic_items.append(entry) + elif label_name in _TEXT_LABELS: + row_ocr_text_blocks.extend(blocks) + + except BaseException as e: + print(f"Warning: OCR failed: {type(e).__name__}: {e}") + row_error = { + "stage": "ocr_page_elements", + "type": e.__class__.__name__, + "message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + if extract_text and row_ocr_text_blocks: + all_text.append(_blocks_to_text(row_ocr_text_blocks)) + else: + all_text.append(None) + + all_table.append(table_items) + all_chart.append(chart_items) + all_infographic.append(infographic_items) + all_ocr_meta.append({"timing": None, "error": row_error}) + + 
elapsed = time.perf_counter() - t0_total + for meta in all_ocr_meta: + meta["timing"] = {"seconds": float(elapsed)} + + out = batch_df.copy() + if extract_tables or "table" not in out.columns: + out["table"] = all_table + if extract_charts or "chart" not in out.columns: + out["chart"] = all_chart + if extract_infographics or "infographic" not in out.columns: + out["infographic"] = all_infographic + if extract_text and "text" in out.columns: + for i, ocr_text in enumerate(all_text): + if ocr_text is not None: + out.iat[i, out.columns.get_loc("text")] = ocr_text + elif extract_text: + out["text"] = [t if t is not None else "" for t in all_text] + out["ocr_v1"] = all_ocr_meta + return out + + +async def anemotron_parse_page_elements( + batch_df: Any, + *, + model: Any = None, + invoke_url: Optional[str] = None, + api_key: Optional[str] = None, + request_timeout_s: float = 120.0, + extract_text: bool = False, + extract_tables: bool = False, + extract_charts: bool = False, + extract_infographics: bool = False, + task_prompt: str = "", + remote_retry: RemoteRetryParams | None = None, + **kwargs: Any, +) -> Any: + """Async version of :func:`nemotron_parse_page_elements`. + + Remote NIM calls use ``ainvoke_image_inference_batches``; local inference + is delegated to ``asyncio.to_thread`` wrapping the sync version. 
+ """ + import asyncio + + invoke_url_resolved = (invoke_url or kwargs.get("nemotron_parse_invoke_url") or "").strip() + use_remote = bool(invoke_url_resolved) + + if not use_remote: + return await asyncio.to_thread( + nemotron_parse_page_elements, + batch_df, + model=model, + invoke_url=invoke_url, + api_key=api_key, + request_timeout_s=request_timeout_s, + extract_text=extract_text, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, + task_prompt=task_prompt, + remote_retry=remote_retry, + **kwargs, + ) + + retry = remote_retry or RemoteRetryParams( + remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), + ) + + if not isinstance(batch_df, pd.DataFrame): + raise NotImplementedError("anemotron_parse_page_elements currently only supports pandas.DataFrame input.") + + wanted_labels: set[str] = set() + if extract_tables: + wanted_labels.add("table") + if extract_charts: + wanted_labels.add("chart") + if extract_infographics: + wanted_labels.add("infographic") + + all_table: List[List[Dict[str, Any]]] = [] + all_chart: List[List[Dict[str, Any]]] = [] + all_infographic: List[List[Dict[str, Any]]] = [] + all_text: List[str] = [] + all_meta: List[Dict[str, Any]] = [] + + t0_total = time.perf_counter() + + for row in batch_df.itertuples(index=False): + table_items: List[Dict[str, Any]] = [] + chart_items: List[Dict[str, Any]] = [] + infographic_items: List[Dict[str, Any]] = [] + row_text: Optional[str] = None + row_error: Any = None + + try: + pe = getattr(row, "page_elements_v3", None) + dets_list: List[Dict[str, Any]] = [] + if isinstance(pe, dict): + dets_list = pe.get("detections") or [] + if not isinstance(dets_list, list): + dets_list = [] + + page_image = getattr(row, "page_image", None) or {} + page_image_b64 = page_image.get("image_b64") if 
isinstance(page_image, dict) else None + if not isinstance(page_image_b64, str) or not page_image_b64: + all_table.append(table_items) + all_chart.append(chart_items) + all_infographic.append(infographic_items) + all_text.append(None) + all_meta.append({"timing": None, "error": None}) + continue + + crops = _crop_all_from_page(page_image_b64, dets_list, wanted_labels, as_b64=True) + if not crops and wanted_labels: + crops = [("full_page", [0.0, 0.0, 1.0, 1.0], page_image_b64)] + + crop_b64s: List[str] = [b64 for _label, _bbox, b64 in crops] + crop_meta_items: List[Tuple[str, List[float]]] = [(label, bbox) for label, bbox, _b64 in crops] + + if crop_b64s: + response_items = await ainvoke_image_inference_batches( + invoke_url=invoke_url_resolved, + image_b64_list=crop_b64s, + api_key=api_key, + timeout_s=float(request_timeout_s), + max_batch_size=int(kwargs.get("inference_batch_size", 8)), + max_concurrency=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + if len(response_items) != len(crop_meta_items): + raise RuntimeError(f"Expected {len(crop_meta_items)} Parse responses, got {len(response_items)}") + + for i, (label_name, bbox) in enumerate(crop_meta_items): + text = _extract_parse_text(response_items[i]) + entry = {"bbox_xyxy_norm": bbox, "text": text} + if label_name == "table": + table_items.append(entry) + elif label_name == "chart": + chart_items.append(entry) + elif label_name == "infographic": + infographic_items.append(entry) + elif label_name == "full_page": + if extract_tables: + table_items.append(dict(entry)) + if extract_charts: + chart_items.append(dict(entry)) + if extract_infographics: + infographic_items.append(dict(entry)) + + meta_row = getattr(row, "metadata", None) or {} + needs_ocr = meta_row.get("needs_ocr_for_text", False) if isinstance(meta_row, dict) else False + if extract_text and needs_ocr: + try: + resp = await ainvoke_image_inference_batches( + 
invoke_url=invoke_url_resolved, + image_b64_list=[page_image_b64], + api_key=api_key, + timeout_s=float(request_timeout_s), + max_batch_size=1, + max_concurrency=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + row_text = _extract_parse_text(resp[0]) if resp else "" + except Exception: + row_text = "" + + except BaseException as e: + print(f"Warning: Nemotron Parse failed: {type(e).__name__}: {e}") + row_error = { + "stage": "nemotron_parse_page_elements", + "type": e.__class__.__name__, + "message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + all_text.append(row_text) + all_table.append(table_items) + all_chart.append(chart_items) + all_infographic.append(infographic_items) + all_meta.append({"timing": None, "error": row_error}) + + elapsed = time.perf_counter() - t0_total + for meta in all_meta: + meta["timing"] = {"seconds": float(elapsed)} + + out = batch_df.copy() + if extract_text and "text" in out.columns: + for i, parse_text in enumerate(all_text): + if parse_text is not None: + out.iat[i, out.columns.get_loc("text")] = parse_text + elif extract_text: + out["text"] = [t if t is not None else "" for t in all_text] + out["table"] = all_table + out["chart"] = all_chart + out["infographic"] = all_infographic + out["table_parse"] = all_table + out["chart_parse"] = all_chart + out["infographic_parse"] = all_infographic + out["nemotron_parse_v1_2"] = all_meta + return out diff --git a/nemo_retriever/src/nemo_retriever/operators/__init__.py b/nemo_retriever/src/nemo_retriever/operators/__init__.py index a0c692ae7..4a67f7b52 100644 --- a/nemo_retriever/src/nemo_retriever/operators/__init__.py +++ b/nemo_retriever/src/nemo_retriever/operators/__init__.py @@ -6,6 +6,6 @@ from nemo_retriever.operators.base import AbstractOperator, CPUOperator, GPUOperator from nemo_retriever.operators.content import ExplodeContentActor -from 
nemo_retriever.operators.embedding import _BatchEmbedActor +from nemo_retriever.operators.embedding import BatchEmbedActor -__all__ = ["AbstractOperator", "CPUOperator", "GPUOperator", "ExplodeContentActor", "_BatchEmbedActor"] +__all__ = ["AbstractOperator", "CPUOperator", "GPUOperator", "ExplodeContentActor", "BatchEmbedActor"] diff --git a/nemo_retriever/src/nemo_retriever/operators/embedding.py b/nemo_retriever/src/nemo_retriever/operators/embedding.py index 46543675b..09ba48e79 100644 --- a/nemo_retriever/src/nemo_retriever/operators/embedding.py +++ b/nemo_retriever/src/nemo_retriever/operators/embedding.py @@ -6,6 +6,6 @@ from __future__ import annotations -from nemo_retriever.text_embed.operators import _BatchEmbedActor, _BatchEmbedCPUActor, _BatchEmbedGPUActor +from nemo_retriever.text_embed.operators import BatchEmbedActor, BatchEmbedCPUActor, BatchEmbedGPUActor -__all__ = ["_BatchEmbedActor", "_BatchEmbedCPUActor", "_BatchEmbedGPUActor"] +__all__ = ["BatchEmbedActor", "BatchEmbedCPUActor", "BatchEmbedGPUActor"] diff --git a/nemo_retriever/src/nemo_retriever/page_elements/cpu_actor.py b/nemo_retriever/src/nemo_retriever/page_elements/cpu_actor.py index d934ada89..61c4b1b33 100644 --- a/nemo_retriever/src/nemo_retriever/page_elements/cpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/page_elements/cpu_actor.py @@ -11,7 +11,7 @@ from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.nim.nim import NIMClient -from nemo_retriever.page_elements.shared import _error_payload, detect_page_elements_v3 +from nemo_retriever.page_elements.shared import _error_payload, adetect_page_elements_v3, detect_page_elements_v3 class PageElementDetectionCPUActor(AbstractOperator, CPUOperator): @@ -53,9 +53,17 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, pages_df: Any, **override_kwargs: 
Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await adetect_page_elements_v3( + data, + model=self._model, + **self.detect_kwargs, + **kwargs, + ) + + async def __call__(self, pages_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(pages_df, **override_kwargs) + return await self.arun(pages_df, **override_kwargs) except Exception as exc: if isinstance(pages_df, pd.DataFrame): out = pages_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/page_elements/gpu_actor.py b/nemo_retriever/src/nemo_retriever/page_elements/gpu_actor.py index 210751e89..838c09b6b 100644 --- a/nemo_retriever/src/nemo_retriever/page_elements/gpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/page_elements/gpu_actor.py @@ -11,7 +11,7 @@ from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.nim.nim import NIMClient -from nemo_retriever.page_elements.shared import _error_payload, detect_page_elements_v3 +from nemo_retriever.page_elements.shared import _error_payload, adetect_page_elements_v3, detect_page_elements_v3 class PageElementDetectionActor(AbstractOperator, GPUOperator): @@ -56,9 +56,17 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, pages_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await adetect_page_elements_v3( + data, + model=self._model, + **self.detect_kwargs, + **kwargs, + ) + + async def __call__(self, pages_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(pages_df, **override_kwargs) + return await self.arun(pages_df, **override_kwargs) except Exception as exc: if isinstance(pages_df, pd.DataFrame): out = pages_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/page_elements/shared.py b/nemo_retriever/src/nemo_retriever/page_elements/shared.py index 
22f8750c6..ab2615a81 100644 --- a/nemo_retriever/src/nemo_retriever/page_elements/shared.py +++ b/nemo_retriever/src/nemo_retriever/page_elements/shared.py @@ -12,7 +12,7 @@ import traceback import pandas as pd -from nemo_retriever.nim.nim import NIMClient, invoke_page_elements_batches +from nemo_retriever.nim.nim import NIMClient, ainvoke_page_elements_batches, invoke_page_elements_batches from nemo_retriever.params import RemoteRetryParams from nemo_retriever.page_elements.local import ( YOLOX_PAGE_V3_CLASS_LABELS, @@ -734,3 +734,130 @@ def detect_page_elements_v3( _counts_by_label(p.get("detections") or []) if isinstance(p, dict) else {} for p in row_payloads ] return out + + +async def adetect_page_elements_v3( + pages_df: Any, + *, + model: Any = None, + invoke_url: Optional[str] = None, + api_key: Optional[str] = None, + request_timeout_s: float = 120.0, + inference_batch_size: int = 8, + output_column: str = "page_elements_v3", + num_detections_column: str = "page_elements_v3_num_detections", + counts_by_label_column: str = "page_elements_v3_counts_by_label", + remote_retry: RemoteRetryParams | None = None, + **kwargs: Any, +) -> Any: + """Async version of :func:`detect_page_elements_v3`. + + Uses ``ainvoke_page_elements_batches`` for non-blocking remote NIM calls. + Delegates local GPU inference to ``asyncio.to_thread`` wrapping the sync + version. 
+ """ + import asyncio + + invoke_url_resolved = (invoke_url or kwargs.get("page_elements_invoke_url") or "").strip() + use_remote = bool(invoke_url_resolved) + + if not use_remote: + return await asyncio.to_thread( + detect_page_elements_v3, + pages_df, + model=model, + invoke_url=invoke_url, + api_key=api_key, + request_timeout_s=request_timeout_s, + inference_batch_size=inference_batch_size, + output_column=output_column, + num_detections_column=num_detections_column, + counts_by_label_column=counts_by_label_column, + remote_retry=remote_retry, + **kwargs, + ) + + retry = remote_retry or RemoteRetryParams( + remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), + ) + + if not isinstance(pages_df, pd.DataFrame): + raise NotImplementedError("adetect_page_elements_v3 currently only supports pandas.DataFrame input.") + + label_names = list(_RETRIEVER_LABEL_NAMES) + thresholds_per_class = [ + YOLOX_PAGE_V3_FINAL_SCORE.get(_RETRIEVER_TO_API.get(name, name), 0.0) for name in label_names + ] + + row_b64: List[Optional[str]] = [] + row_payloads: List[Dict[str, Any]] = [] + + for _, row in pages_df.iterrows(): + try: + b64 = row.get("page_image")["image_b64"] + if not b64: + raise ValueError("No usable image_b64 found in row.") + row_b64.append(b64) + row_payloads.append({"detections": []}) + except BaseException as e: + row_b64.append(None) + row_payloads.append(_error_payload(stage="decode_image", exc=e)) + + valid_indices = [i for i, b64 in enumerate(row_b64) if b64] + + if valid_indices: + valid_b64: List[str] = [] + for row_i in valid_indices: + b64 = row_b64[row_i] + if b64: + valid_b64.append(b64) + + t0 = time.perf_counter() + try: + response_items = await ainvoke_page_elements_batches( + invoke_url=invoke_url_resolved, + image_b64_list=valid_b64, + api_key=api_key, + timeout_s=float(request_timeout_s), + 
max_batch_size=int(inference_batch_size), + max_concurrency=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + elapsed = time.perf_counter() - t0 + + if len(response_items) != len(valid_indices): + raise RuntimeError( + "Remote response count mismatch: " f"expected {len(valid_indices)}, got {len(response_items)}" + ) + + for local_i, row_i in enumerate(valid_indices): + dets = _remote_response_to_detections( + response_json=response_items[local_i], + label_names=label_names, + thresholds_per_class=thresholds_per_class, + ) + row_payloads[row_i] = { + "detections": dets, + "timing": {"seconds": float(elapsed)}, + "error": None, + } + except BaseException as e: + elapsed = time.perf_counter() - t0 + print(f"Warning: page_elements remote inference failed: {type(e).__name__}: {e}") + for row_i in valid_indices: + row_payloads[row_i] = _error_payload(stage="remote_inference", exc=e) | { + "timing": {"seconds": float(elapsed)} + } + + out = pages_df.copy() + out[output_column] = row_payloads + out[num_detections_column] = [ + int(len(p.get("detections") or [])) if isinstance(p, dict) else 0 for p in row_payloads + ] + out[counts_by_label_column] = [ + _counts_by_label(p.get("detections") or []) if isinstance(p, dict) else {} for p in row_payloads + ] + return out diff --git a/nemo_retriever/src/nemo_retriever/parse/nemotron_parse.py b/nemo_retriever/src/nemo_retriever/parse/nemotron_parse.py index 0ce0f0996..39b975942 100644 --- a/nemo_retriever/src/nemo_retriever/parse/nemotron_parse.py +++ b/nemo_retriever/src/nemo_retriever/parse/nemotron_parse.py @@ -499,9 +499,9 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + 
return await self.arun(batch_df, **override_kwargs) except BaseException as e: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() @@ -587,9 +587,9 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as e: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/pdf/extract.py b/nemo_retriever/src/nemo_retriever/pdf/extract.py index e3272ec5c..25b9821a5 100644 --- a/nemo_retriever/src/nemo_retriever/pdf/extract.py +++ b/nemo_retriever/src/nemo_retriever/pdf/extract.py @@ -421,9 +421,9 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, pdf: Any, **override_kwargs: Any) -> Optional[Any]: + async def __call__(self, pdf: Any, **override_kwargs: Any) -> Optional[Any]: try: - return self.run(pdf, **override_kwargs) + return await self.arun(pdf, **override_kwargs) except BaseException as e: # As a last line of defense, never let the Ray UDF raise. 
source_path = None diff --git a/nemo_retriever/src/nemo_retriever/pdf/split.py b/nemo_retriever/src/nemo_retriever/pdf/split.py index ab9ca2381..6bfcdc419 100644 --- a/nemo_retriever/src/nemo_retriever/pdf/split.py +++ b/nemo_retriever/src/nemo_retriever/pdf/split.py @@ -194,8 +194,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, pdf_batch: Any) -> Any: - return self.run(pdf_batch) + async def __call__(self, pdf_batch: Any) -> Any: + return await self.arun(pdf_batch) class PDFSplitActor(ArchetypeOperator): diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index e5be26dea..58ced38e3 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -17,7 +17,6 @@ logger = logging.getLogger(__name__) AUDIO_MATCH_TOLERANCE_SECS = 2.0 -import numpy as np import pandas as pd @@ -230,77 +229,6 @@ def _normalize_query_df(df: pd.DataFrame, *, match_mode: str) -> pd.DataFrame: return df -def _resolve_embedding_endpoint(cfg: RecallConfig) -> Tuple[Optional[str], Optional[bool]]: - """ - Resolve which embedding endpoint to use. - - Returns (endpoint, use_grpc) where: - - endpoint is either an http(s) URL or a host:port string for gRPC - - use_grpc is True for gRPC, False for HTTP, None when no endpoint is configured - """ - http_ep = (cfg.embedding_http_endpoint or "").strip() if isinstance(cfg.embedding_http_endpoint, str) else None - grpc_ep = (cfg.embedding_grpc_endpoint or "").strip() if isinstance(cfg.embedding_grpc_endpoint, str) else None - single = (cfg.embedding_endpoint or "").strip() if isinstance(cfg.embedding_endpoint, str) else None - - if http_ep: - return http_ep, False - if grpc_ep: - return grpc_ep, True - if single: - # Infer protocol: if a URL scheme is present, treat as HTTP; otherwise gRPC. 
- return single, (not single.lower().startswith("http")) - - return None, None - - -def _embed_queries_nim( - queries: List[str], - *, - endpoint: str, - model: str, - api_key: str, - grpc: bool, -) -> List[List[float]]: - from nv_ingest_api.util.nim import infer_microservice - - # `infer_microservice` returns a list of embeddings. - embeddings = infer_microservice( - queries, - model_name=model, - embedding_endpoint=endpoint, - nvidia_api_key=(api_key or "").strip(), - grpc=bool(grpc), - input_type="query", - ) - # Some backends return numpy arrays; normalize to list-of-list floats. - out: List[List[float]] = [] - for e in embeddings: - if isinstance(e, np.ndarray): - out.append(e.astype("float32").tolist()) - else: - out.append(list(e)) - return out - - -def _embed_queries_local_hf( - queries: List[str], - *, - device: Optional[str], - cache_dir: Optional[str], - batch_size: int, - model_name: Optional[str] = None, -) -> List[List[float]]: - from nemo_retriever.model import create_local_embedder, is_vl_embed_model - - embedder = create_local_embedder(model_name, device=device, hf_cache_dir=cache_dir) - - if is_vl_embed_model(model_name): - vecs = embedder.embed_queries(queries, batch_size=int(batch_size)) - else: - vecs = embedder.embed(["query: " + q for q in queries], batch_size=int(batch_size)) - return vecs.detach().to("cpu").tolist() - - def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]: retrieved_keys: List[List[str]] = [] for hits in raw_hits: @@ -517,12 +445,12 @@ def retrieve_and_score( queries = df_query["query"].astype(str).tolist() gold = df_query["golden_answer"].astype(str).tolist() - endpoint, use_grpc = _resolve_embedding_endpoint(cfg) retriever = Retriever( lancedb_uri=cfg.lancedb_uri, lancedb_table=cfg.lancedb_table, embedder=cfg.embedding_model or VL_EMBED_MODEL, embedding_http_endpoint=cfg.embedding_http_endpoint, + embedding_endpoint=cfg.embedding_endpoint, embedding_api_key=cfg.embedding_api_key, top_k=cfg.top_k, 
nprobes=cfg.nprobes, diff --git a/nemo_retriever/src/nemo_retriever/rerank/rerank.py b/nemo_retriever/src/nemo_retriever/rerank/rerank.py index 5061c51f4..ed5c6c259 100644 --- a/nemo_retriever/src/nemo_retriever/rerank/rerank.py +++ b/nemo_retriever/src/nemo_retriever/rerank/rerank.py @@ -398,9 +398,9 @@ def process(self, batch_df: Any, **override_kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() @@ -432,9 +432,9 @@ def process(self, batch_df: Any, **override_kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 59d3e585a..5ed2b0366 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -4,7 +4,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from pathlib import Path from typing import Any, Optional, Sequence from tqdm import tqdm @@ -94,6 +94,16 @@ class Retriever: # Internal cache for local HF embedders, keyed by model name. 
_embedder_cache: dict = field(default_factory=dict, init=False, repr=False, compare=False) + def __str__(self) -> str: + lines = [f"{self.__class__.__name__}("] + for f in fields(self): + if f.name.startswith("_"): + continue + val = getattr(self, f.name) + lines.append(f" {f.name}={val!r},") + lines.append(")") + return "\n".join(lines) + def _resolve_embedding_endpoint(self) -> Optional[str]: http_ep = self.embedding_http_endpoint.strip() if isinstance(self.embedding_http_endpoint, str) else None single = self.embedding_endpoint.strip() if isinstance(self.embedding_endpoint, str) else None diff --git a/nemo_retriever/src/nemo_retriever/table/cpu_actor.py b/nemo_retriever/src/nemo_retriever/table/cpu_actor.py index a8c69e20c..251ac4afd 100644 --- a/nemo_retriever/src/nemo_retriever/table/cpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/table/cpu_actor.py @@ -15,7 +15,7 @@ from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.table.shared import table_structure_ocr_page_elements +from nemo_retriever.table.shared import atable_structure_ocr_page_elements, table_structure_ocr_page_elements logger = logging.getLogger(__name__) @@ -149,9 +149,21 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await atable_structure_ocr_page_elements( + data, + table_structure_model=self._table_structure_model, + table_structure_invoke_url=self._table_structure_invoke_url, + api_key=self._api_key, + table_output_format=self._table_output_format, + request_timeout_s=self._request_timeout_s, + remote_retry=self._remote_retry, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, 
**override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/table/gpu_actor.py b/nemo_retriever/src/nemo_retriever/table/gpu_actor.py index fd33dc7a8..697fb38cb 100644 --- a/nemo_retriever/src/nemo_retriever/table/gpu_actor.py +++ b/nemo_retriever/src/nemo_retriever/table/gpu_actor.py @@ -12,7 +12,7 @@ from nemo_retriever.graph.gpu_operator import GPUOperator from nemo_retriever.nim.nim import NIMClient from nemo_retriever.params import RemoteRetryParams -from nemo_retriever.table.shared import table_structure_ocr_page_elements +from nemo_retriever.table.shared import atable_structure_ocr_page_elements, table_structure_ocr_page_elements class TableStructureActor(AbstractOperator, GPUOperator): @@ -84,9 +84,21 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def aprocess(self, data: Any, **kwargs: Any) -> Any: + return await atable_structure_ocr_page_elements( + data, + table_structure_model=self._table_structure_model, + table_structure_invoke_url=self._table_structure_invoke_url, + api_key=self._api_key, + table_output_format=self._table_output_format, + request_timeout_s=self._request_timeout_s, + remote_retry=self._remote_retry, + **kwargs, + ) + + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as exc: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/table/shared.py b/nemo_retriever/src/nemo_retriever/table/shared.py index f2b0b1394..56d9be381 100644 --- a/nemo_retriever/src/nemo_retriever/table/shared.py +++ b/nemo_retriever/src/nemo_retriever/table/shared.py 
@@ -485,6 +485,158 @@ def _run_remote_ts() -> List[Any]: return out +async def atable_structure_ocr_page_elements( + batch_df: Any, + *, + table_structure_model: Any = None, + table_structure_invoke_url: str = "", + api_key: str = "", + request_timeout_s: float = 120.0, + remote_retry: RemoteRetryParams | None = None, + **kwargs: Any, +) -> Any: + """Async version of :func:`table_structure_ocr_page_elements`. + + Uses ``ainvoke_image_inference_batches`` for non-blocking remote NIM calls. + Falls back to ``asyncio.to_thread`` for local GPU inference. + """ + import asyncio + + from nemo_retriever.nim.nim import ainvoke_image_inference_batches + from nemo_retriever.ocr.ocr import _crop_all_from_page, _np_rgb_to_b64_png + + retry = remote_retry or RemoteRetryParams( + remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)), + remote_max_retries=int(kwargs.get("remote_max_retries", 10)), + remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)), + ) + + if not isinstance(batch_df, pd.DataFrame): + raise NotImplementedError("atable_structure_ocr_page_elements currently only supports pandas.DataFrame input.") + + ts_url = (table_structure_invoke_url or kwargs.get("table_structure_invoke_url") or "").strip() + use_remote_ts = bool(ts_url) + table_output_format = kwargs.get("table_output_format") + + if not use_remote_ts and table_structure_model is None: + raise ValueError("A local `table_structure_model` is required when `table_structure_invoke_url` is not set.") + + label_names = _labels_from_model(table_structure_model) if table_structure_model is not None else [] + if not label_names: + label_names = _DEFAULT_TABLE_STRUCTURE_LABELS + inference_batch_size = int(kwargs.get("inference_batch_size", 8)) + + all_table: List[List[Dict[str, Any]]] = [] + all_meta: List[Dict[str, Any]] = [] + + t0_total = time.perf_counter() + + for row in batch_df.itertuples(index=False): + table_items: List[Dict[str, Any]] = [] + row_error: Any = None + + try: + pe 
= getattr(row, "page_elements_v3", None) + dets: List[Dict[str, Any]] = [] + if isinstance(pe, dict): + dets = pe.get("detections") or [] + if not isinstance(dets, list): + dets = [] + + page_image = getattr(row, "page_image", None) or {} + page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None + + if not isinstance(page_image_b64, str) or not page_image_b64: + all_table.append(table_items) + all_meta.append({"timing": None, "error": None}) + continue + + crops = _crop_all_from_page(page_image_b64, dets, {"table"}) + + if not crops: + all_table.append(table_items) + all_meta.append({"timing": None, "error": None}) + continue + + crop_b64s = [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops] if use_remote_ts else [] + + structure_results: List[List[Dict[str, Any]]] = [] + if use_remote_ts: + response_items = await ainvoke_image_inference_batches( + invoke_url=ts_url, + image_b64_list=crop_b64s, + api_key=api_key or None, + timeout_s=float(request_timeout_s), + max_batch_size=inference_batch_size, + max_concurrency=int(retry.remote_max_pool_workers), + max_retries=int(retry.remote_max_retries), + max_429_retries=int(retry.remote_max_429_retries), + ) + if len(response_items) != len(crops): + raise RuntimeError(f"Expected {len(crops)} table-structure responses, got {len(response_items)}") + for resp in response_items: + parsed = _parse_nim_bounding_boxes(resp) + if not parsed: + pred_item = _extract_remote_pred_item(resp) + parsed = _prediction_to_detections(pred_item, label_names=label_names) + structure_results.append([d for d in parsed if (d.get("score") or 0.0) >= YOLOX_TABLE_MIN_SCORE]) + else: + + def _run_local_ts(): + results = [] + for _, _, crop_array in crops: + chw = torch.from_numpy(crop_array).permute(2, 0, 1).contiguous().to(dtype=torch.float32) + h, w = crop_array.shape[:2] + x = chw.unsqueeze(0) + try: + pre = table_structure_model.preprocess(x, (h, w)) + except TypeError: + pre = 
table_structure_model.preprocess(x) + if isinstance(pre, torch.Tensor) and pre.ndim == 3: + pre = pre.unsqueeze(0) + pred = table_structure_model.invoke(pre, (h, w)) + dets_local = _prediction_to_detections(pred, label_names=label_names) + results.append([d for d in dets_local if (d.get("score") or 0.0) >= YOLOX_TABLE_MIN_SCORE]) + return results + + structure_results = await asyncio.to_thread(_run_local_ts) + + for crop_i, (_, bbox, _) in enumerate(crops): + structure_dets = structure_results[crop_i] + table_items.append( + { + "bbox_xyxy_norm": bbox, + "text": _render_structure_only_text( + structure_dets, + table_output_format=table_output_format, + ), + "structure_detections": structure_dets, + "structure_counts": _count_structure_labels(structure_dets), + } + ) + + except BaseException as e: + print(f"Warning: table-structure failed: {type(e).__name__}: {e}") + row_error = { + "stage": "table_structure_ocr_page_elements", + "type": e.__class__.__name__, + "message": str(e), + "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)), + } + + all_table.append(table_items) + all_meta.append({"timing": None, "error": row_error}) + + elapsed = time.perf_counter() - t0_total + for meta in all_meta: + meta["timing"] = {"seconds": float(elapsed)} + + out = batch_df.copy() + out["table"] = all_table + out["table_structure_ocr_v1"] = all_meta + return out + + # --------------------------------------------------------------------------- # Combined table-structure + OCR Ray Actor # --------------------------------------------------------------------------- diff --git a/nemo_retriever/src/nemo_retriever/text_embed/cpu_operator.py b/nemo_retriever/src/nemo_retriever/text_embed/cpu_operator.py index c9517aa9f..2216f87e7 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/cpu_operator.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/cpu_operator.py @@ -15,7 +15,7 @@ from nemo_retriever.text_embed.shared import build_embed_kwargs -class 
_BatchEmbedCPUActor(AbstractOperator, CPUOperator): +class BatchEmbedCPUActor(AbstractOperator, CPUOperator): """CPU-only embedding actor that always targets a remote endpoint.""" DEFAULT_EMBED_INVOKE_URL = "https://integrate.api.nvidia.com/v1/embeddings" diff --git a/nemo_retriever/src/nemo_retriever/text_embed/gpu_operator.py b/nemo_retriever/src/nemo_retriever/text_embed/gpu_operator.py index 0eb60bddc..090f490c0 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/gpu_operator.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/gpu_operator.py @@ -15,7 +15,7 @@ from nemo_retriever.text_embed.shared import build_embed_kwargs -class _BatchEmbedActor(AbstractOperator, GPUOperator): +class BatchEmbedGPUActor(AbstractOperator, GPUOperator): """Graph embedding actor that loads a local embedder or calls a remote endpoint.""" def __init__(self, params: EmbedParams) -> None: @@ -55,5 +55,5 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any) -> Any: - return self.run(batch_df) + async def __call__(self, batch_df: Any) -> Any: + return await self.arun(batch_df) diff --git a/nemo_retriever/src/nemo_retriever/text_embed/operators.py b/nemo_retriever/src/nemo_retriever/text_embed/operators.py index 473a361e4..a845b7fe3 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/operators.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/operators.py @@ -12,7 +12,7 @@ from nemo_retriever.graph.operator_archetype import ArchetypeOperator from nemo_retriever.text_embed.runtime import embed_text_main_text_embed -__all__ = ["_BatchEmbedActor", "embed_text_main_text_embed"] +__all__ = ["BatchEmbedActor", "embed_text_main_text_embed"] @designer_component( @@ -22,7 +22,7 @@ description="Generates embeddings in batches using configurable embedding parameters", category_color="#e06cff", ) -class _BatchEmbedActor(ArchetypeOperator): +class 
BatchEmbedActor(ArchetypeOperator): """Graph-facing batch embedding archetype.""" @classmethod @@ -33,27 +33,37 @@ def prefers_cpu_variant(cls, operator_kwargs: dict[str, Any] | None = None) -> b @classmethod def cpu_variant_class(cls): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - return _BatchEmbedCPUActor + return BatchEmbedCPUActor @classmethod def gpu_variant_class(cls): - from nemo_retriever.text_embed.gpu_operator import _BatchEmbedActor as _BatchEmbedGPUActor + from nemo_retriever.text_embed.gpu_operator import BatchEmbedGPUActor - return _BatchEmbedGPUActor + return BatchEmbedGPUActor def __init__(self, params: Any) -> None: super().__init__(params=params) def __getattr__(name: str): + if name == "BatchEmbedCPUActor": + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor + + return BatchEmbedCPUActor + if name == "BatchEmbedGPUActor": + from nemo_retriever.text_embed.gpu_operator import BatchEmbedGPUActor + + return BatchEmbedGPUActor + if name == "_BatchEmbedActor": + return BatchEmbedActor if name == "_BatchEmbedCPUActor": - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - return _BatchEmbedCPUActor + return BatchEmbedCPUActor if name == "_BatchEmbedGPUActor": - from nemo_retriever.text_embed.gpu_operator import _BatchEmbedActor as _BatchEmbedGPUActor + from nemo_retriever.text_embed.gpu_operator import BatchEmbedGPUActor - return _BatchEmbedGPUActor + return BatchEmbedGPUActor raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py index 08e528a44..e4de03e2b 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/text_embed.py @@ 
-223,9 +223,9 @@ def process(self, batch_df: Any, **override_kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: + async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any: try: - return self.run(batch_df, **override_kwargs) + return await self.arun(batch_df, **override_kwargs) except BaseException as e: if isinstance(batch_df, pd.DataFrame): out = batch_df.copy() diff --git a/nemo_retriever/src/nemo_retriever/txt/ray_data.py b/nemo_retriever/src/nemo_retriever/txt/ray_data.py index 7bdbd2dd8..69ced0d22 100644 --- a/nemo_retriever/src/nemo_retriever/txt/ray_data.py +++ b/nemo_retriever/src/nemo_retriever/txt/ray_data.py @@ -57,8 +57,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: - return self.run(batch_df) + async def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: + return await self.arun(batch_df) @designer_component( @@ -111,8 +111,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: - return self.run(batch_df) + async def __call__(self, batch_df: pd.DataFrame) -> pd.DataFrame: + return await self.arun(batch_df) class TextChunkActor(ArchetypeOperator): diff --git a/nemo_retriever/src/nemo_retriever/utils/convert/to_pdf.py b/nemo_retriever/src/nemo_retriever/utils/convert/to_pdf.py index bbe94d346..fd207991f 100644 --- a/nemo_retriever/src/nemo_retriever/utils/convert/to_pdf.py +++ b/nemo_retriever/src/nemo_retriever/utils/convert/to_pdf.py @@ -179,8 +179,8 @@ def process(self, data: Any, **kwargs: Any) -> Any: def postprocess(self, data: Any, **kwargs: Any) -> Any: return data - def __call__(self, batch_df: Any) -> Any: - return self.run(batch_df) + async def 
__call__(self, batch_df: Any) -> Any: + return await self.arun(batch_df) class DocToPdfConversionActor(ArchetypeOperator): diff --git a/nemo_retriever/tests/test_actor_operators.py b/nemo_retriever/tests/test_actor_operators.py index 1f857e63a..0c8976c77 100644 --- a/nemo_retriever/tests/test_actor_operators.py +++ b/nemo_retriever/tests/test_actor_operators.py @@ -5,12 +5,13 @@ """Unit tests verifying all pipeline actors inherit from AbstractOperator.""" from pathlib import Path -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pandas as pd import pytest from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.tests.testing_utils import _run # --------------------------------------------------------------------------- @@ -52,7 +53,7 @@ def test_call_delegates_to_run(self, mock_fn): expected = pd.DataFrame({"page": [1]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"bytes": [b"x"]})) + result = _run(actor(pd.DataFrame({"bytes": [b"x"]}))) pd.testing.assert_frame_equal(result, expected) @@ -89,14 +90,14 @@ def test_call_delegates_to_run(self, mock_fn): expected = pd.DataFrame({"text": ["hello"]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"bytes": [b"x"]})) + result = _run(actor(pd.DataFrame({"bytes": [b"x"]}))) pd.testing.assert_frame_equal(result, expected) @patch("nemo_retriever.pdf.extract.pdf_extraction", side_effect=RuntimeError("boom")) def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"bytes": [b"x"], "path": ["/tmp/a.pdf"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, list) record = result[0] assert record["metadata"]["error"]["type"] == "RuntimeError" @@ -110,15 +111,17 @@ def test_pdfium_output_can_have_empty_text_without_ocr_flag(self): pytest.skip(f"External regression fixture not available: {pdf_path}") source_df = pd.DataFrame({"path": 
[str(pdf_path)], "bytes": [pdf_path.read_bytes()]}) - split_df = PDFSplitActor()(source_df) - - result = PDFExtractionActor( - method="pdfium", - extract_text=True, - extract_tables=True, - extract_charts=True, - extract_infographics=True, - )(split_df.head(5)) + split_df = _run(PDFSplitActor()(source_df)) + + result = _run( + PDFExtractionActor( + method="pdfium", + extract_text=True, + extract_tables=True, + extract_charts=True, + extract_infographics=True, + )(split_df.head(5)) + ) first_page = result[result["page_number"] == 1].iloc[0] metadata = first_page["metadata"] @@ -157,19 +160,23 @@ def test_process(self, mock_fn): mock_fn.assert_called_once() pd.testing.assert_frame_equal(result, expected) - @patch("nemo_retriever.page_elements.cpu_actor.detect_page_elements_v3") + @patch("nemo_retriever.page_elements.cpu_actor.adetect_page_elements_v3", new_callable=AsyncMock) def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"page_elements_v3": ["det"]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"page_image": ["x"]})) + result = _run(actor(pd.DataFrame({"page_image": ["x"]}))) pd.testing.assert_frame_equal(result, expected) - @patch("nemo_retriever.page_elements.cpu_actor.detect_page_elements_v3", side_effect=RuntimeError("boom")) + @patch( + "nemo_retriever.page_elements.cpu_actor.adetect_page_elements_v3", + new_callable=AsyncMock, + side_effect=RuntimeError("boom"), + ) def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"page_image": ["x"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, pd.DataFrame) assert "page_elements_v3" in result.columns @@ -205,11 +212,15 @@ def test_process(self, mock_fn): mock_fn.assert_called_once() pd.testing.assert_frame_equal(result, expected) - @patch("nemo_retriever.chart.cpu_actor.graphic_elements_ocr_page_elements", side_effect=RuntimeError("boom")) + @patch( + 
"nemo_retriever.chart.cpu_actor.agraphic_elements_ocr_page_elements", + new_callable=AsyncMock, + side_effect=RuntimeError("boom"), + ) def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"page_image": ["x"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, pd.DataFrame) assert "graphic_elements_ocr_v1" in result.columns @@ -242,11 +253,15 @@ def test_process(self, mock_fn): mock_fn.assert_called_once() pd.testing.assert_frame_equal(result, expected) - @patch("nemo_retriever.table.cpu_actor.table_structure_ocr_page_elements", side_effect=RuntimeError("boom")) + @patch( + "nemo_retriever.table.cpu_actor.atable_structure_ocr_page_elements", + new_callable=AsyncMock, + side_effect=RuntimeError("boom"), + ) def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"page_image": ["x"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, pd.DataFrame) assert "table_structure_ocr_v1" in result.columns @@ -279,11 +294,11 @@ def test_process(self, mock_fn): mock_fn.assert_called_once() pd.testing.assert_frame_equal(result, expected) - @patch("nemo_retriever.ocr.cpu_ocr.ocr_page_elements", side_effect=RuntimeError("boom")) + @patch("nemo_retriever.ocr.cpu_ocr.aocr_page_elements", new_callable=AsyncMock, side_effect=RuntimeError("boom")) def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"page_image": ["x"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, pd.DataFrame) assert "ocr_v1" in result.columns @@ -320,7 +335,7 @@ def test_process(self, mock_fn): def test_call_error_handling(self, mock_fn): actor = self._make() df = pd.DataFrame({"page_image": ["x"]}) - result = actor(df) + result = _run(actor(df)) assert isinstance(result, pd.DataFrame) assert "nemotron_parse_v1_2" in result.columns @@ -360,7 +375,7 @@ def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"text": ["chunk1"]}) 
mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"text": ["hello world"]})) + result = _run(actor(pd.DataFrame({"text": ["hello world"]}))) pd.testing.assert_frame_equal(result, expected) @@ -403,7 +418,7 @@ def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"path": ["/tmp/a.png"], "page_number": [0]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"bytes": [b"img"], "path": ["/tmp/a.png"]})) + result = _run(actor(pd.DataFrame({"bytes": [b"img"], "path": ["/tmp/a.png"]}))) pd.testing.assert_frame_equal(result, expected) @@ -446,7 +461,7 @@ def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"text": ["chunk"], "path": ["/a.txt"], "page_number": [0], "metadata": [{}]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"bytes": [b"hello"], "path": ["/a.txt"]})) + result = _run(actor(pd.DataFrame({"bytes": [b"hello"], "path": ["/a.txt"]}))) pd.testing.assert_frame_equal(result, expected) @@ -484,25 +499,25 @@ def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"text": ["chunk"], "path": ["/a.html"], "page_number": [0], "metadata": [{}]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"bytes": [b"

hi

"], "path": ["/a.html"]})) + result = _run(actor(pd.DataFrame({"bytes": [b"

hi

"], "path": ["/a.html"]}))) pd.testing.assert_frame_equal(result, expected) # --------------------------------------------------------------------------- -# 12. _BatchEmbedActor +# 12. BatchEmbedActor # --------------------------------------------------------------------------- class TestBatchEmbedActor: def _make(self): from nemo_retriever.params import EmbedParams - from nemo_retriever.text_embed.operators import _BatchEmbedActor + from nemo_retriever.text_embed.operators import BatchEmbedActor params = EmbedParams(model_name="test-model", embed_invoke_url="http://fake") - return _BatchEmbedActor(params=params) + return BatchEmbedActor(params=params) def test_inherits(self): - from nemo_retriever.text_embed.operators import _BatchEmbedActor + from nemo_retriever.text_embed.operators import BatchEmbedActor - assert issubclass(_BatchEmbedActor, AbstractOperator) + assert issubclass(BatchEmbedActor, AbstractOperator) def test_preprocess_passthrough(self): actor = self._make() @@ -528,5 +543,5 @@ def test_call_delegates(self, mock_fn): expected = pd.DataFrame({"text": ["hello"], "embedding": [[0.1, 0.2]]}) mock_fn.return_value = expected actor = self._make() - result = actor(pd.DataFrame({"text": ["hello"]})) + result = _run(actor(pd.DataFrame({"text": ["hello"]}))) pd.testing.assert_frame_equal(result, expected) diff --git a/nemo_retriever/tests/test_asr_actor.py b/nemo_retriever/tests/test_asr_actor.py index c7297e391..b98494232 100644 --- a/nemo_retriever/tests/test_asr_actor.py +++ b/nemo_retriever/tests/test_asr_actor.py @@ -20,6 +20,7 @@ from nemo_retriever.audio.asr_actor import ASRActor, ASRCPUActor from nemo_retriever.audio.asr_actor import apply_asr_to_df from nemo_retriever.params import ASRParams +from nemo_retriever.tests.testing_utils import _run def test_strip_pad_from_transcript(): @@ -49,7 +50,7 @@ def test_asr_actor_empty_batch(): params = ASRParams(audio_endpoints=("localhost:50051", None)) actor = ASRActor(params=params) empty = 
pd.DataFrame(columns=["path", "bytes"]) - out = actor(empty) + out = _run(actor(empty)) assert isinstance(out, pd.DataFrame) assert "text" in out.columns @@ -79,7 +80,7 @@ def test_asr_actor_mock_transcribe(): } ] ) - out = actor(batch) + out = _run(actor(batch)) assert len(out) == 1 assert out["text"].iloc[0] == "hello world transcript" @@ -142,7 +143,7 @@ def test_asr_actor_remote_segment_audio(): } ] ) - out = actor(batch) + out = _run(actor(batch)) assert len(out) == 2 assert out["text"].tolist() == ["Hello world.", "How are you?"] @@ -223,7 +224,7 @@ def test_local_asr_does_not_call_get_client(): } ] ) - out = actor(batch) + out = _run(actor(batch)) assert len(out) == 1 assert out["text"].iloc[0] == "mocked local transcript" diff --git a/nemo_retriever/tests/test_audio_chunk_actor.py b/nemo_retriever/tests/test_audio_chunk_actor.py index 293825488..4874fab4e 100644 --- a/nemo_retriever/tests/test_audio_chunk_actor.py +++ b/nemo_retriever/tests/test_audio_chunk_actor.py @@ -17,6 +17,7 @@ from nemo_retriever.audio.chunk_actor import audio_path_to_chunks_df from nemo_retriever.audio.media_interface import is_media_available from nemo_retriever.params import AudioChunkParams +from nemo_retriever.tests.testing_utils import _run def _make_small_wav(path: Path, duration_sec: float = 0.5, sample_rate: int = 8000) -> None: @@ -36,7 +37,7 @@ def test_media_chunk_actor_empty_batch(): params = AudioChunkParams(split_type="size", split_interval=1000) actor = MediaChunkActor(params=params) empty = pd.DataFrame(columns=["path", "bytes"]) - out = actor(empty) + out = _run(actor(empty)) assert isinstance(out, pd.DataFrame) assert list(out.columns) == CHUNK_COLUMNS assert len(out) == 0 @@ -54,7 +55,7 @@ def test_media_chunk_actor_single_small_file(tmp_path: Path): params = AudioChunkParams(split_type="size", split_interval=1_000_000) actor = MediaChunkActor(params=params) batch = pd.DataFrame([{"path": str(wav.resolve()), "bytes": body}]) - out = actor(batch) + out = 
_run(actor(batch)) assert isinstance(out, pd.DataFrame) for col in ["path", "source_path", "duration", "chunk_index", "metadata", "page_number", "bytes"]: diff --git a/nemo_retriever/tests/test_chart_graphic_elements.py b/nemo_retriever/tests/test_chart_graphic_elements.py index 88e4433a4..c1dc6e0aa 100644 --- a/nemo_retriever/tests/test_chart_graphic_elements.py +++ b/nemo_retriever/tests/test_chart_graphic_elements.py @@ -14,6 +14,7 @@ import pandas as pd import pytest +from nemo_retriever.tests.testing_utils import _run from nemo_retriever.utils.table_and_chart import join_graphic_elements_and_ocr_output @@ -291,7 +292,7 @@ def test_actor_error_returns_dataframe_with_error(self) -> None: df = _make_chart_page_df() # This will fail because both models are None and no URLs set. - result = actor(df) + result = _run(actor(df)) assert "chart" in result.columns assert "graphic_elements_ocr_v1" in result.columns meta = result.iloc[0]["graphic_elements_ocr_v1"] diff --git a/nemo_retriever/tests/test_doc_to_pdf_actor.py b/nemo_retriever/tests/test_doc_to_pdf_actor.py index c282aa9f2..4a41e9259 100644 --- a/nemo_retriever/tests/test_doc_to_pdf_actor.py +++ b/nemo_retriever/tests/test_doc_to_pdf_actor.py @@ -8,6 +8,7 @@ import pytest from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.tests.testing_utils import _run from nemo_retriever.utils.convert.to_pdf import DocToPdfConversionActor, convert_to_pdf_bytes, convert_batch_to_pdf @@ -79,6 +80,6 @@ def test_call_delegates_to_run(self, mock_convert): mock_convert.return_value = expected actor = DocToPdfConversionActor() df = pd.DataFrame({"bytes": [b"docx"], "path": ["/tmp/test.docx"]}) - result = actor(df) + result = _run(actor(df)) mock_convert.assert_called_once_with(df) pd.testing.assert_frame_equal(result, expected) diff --git a/nemo_retriever/tests/test_image_load.py b/nemo_retriever/tests/test_image_load.py index 623056059..63805e1c6 100644 --- 
a/nemo_retriever/tests/test_image_load.py +++ b/nemo_retriever/tests/test_image_load.py @@ -22,6 +22,8 @@ image_file_to_pages_df, ) from nemo_retriever.image.ray_data import ImageLoadActor # noqa: E402 +from nemo_retriever.tests.testing_utils import _run # noqa: E402 + # -- Helpers ------------------------------------------------------------------ @@ -175,7 +177,7 @@ def test_batch_processing(self) -> None: {"bytes": img2, "path": "/b/img2.png"}, ] ) - result = actor(batch) + result = _run(actor(batch)) assert isinstance(result, pd.DataFrame) assert len(result) == 2 assert list(result["path"]) == ["/a/img1.png", "/b/img2.png"] @@ -184,14 +186,14 @@ def test_batch_processing(self) -> None: def test_empty_batch(self) -> None: actor = ImageLoadActor() - result = actor(pd.DataFrame()) + result = _run(actor(pd.DataFrame())) assert isinstance(result, pd.DataFrame) assert len(result) == 0 def test_missing_columns_skipped(self) -> None: actor = ImageLoadActor() batch = pd.DataFrame([{"bytes": b"data"}]) # no 'path' column - result = actor(batch) + result = _run(actor(batch)) assert len(result) == 0 def test_corrupt_row_skipped(self) -> None: @@ -203,7 +205,7 @@ def test_corrupt_row_skipped(self) -> None: {"bytes": good, "path": "/good.png"}, ] ) - result = actor(batch) + result = _run(actor(batch)) # Corrupt row produces an error record, good row succeeds. 
assert len(result) == 2 diff --git a/nemo_retriever/tests/test_ingest_plans.py b/nemo_retriever/tests/test_ingest_plans.py index ae9e991bd..c130bc02c 100644 --- a/nemo_retriever/tests/test_ingest_plans.py +++ b/nemo_retriever/tests/test_ingest_plans.py @@ -7,7 +7,7 @@ from nemo_retriever.graph.pipeline_graph import Graph from nemo_retriever.ocr.ocr import OCRActor from nemo_retriever.page_elements.page_elements import PageElementDetectionActor -from nemo_retriever.text_embed.operators import _BatchEmbedActor +from nemo_retriever.text_embed.operators import BatchEmbedActor from nemo_retriever.graph.operator_archetype import ArchetypeOperator from nemo_retriever.graph.cpu_operator import CPUOperator from nemo_retriever.graph.gpu_operator import GPUOperator @@ -101,7 +101,7 @@ def test_build_graph_accepts_execution_plan() -> None: break node = node.children[0] - assert names == ["MultiTypeExtractOperator", "TextChunkActor", "_BatchEmbedActor"] + assert names == ["MultiTypeExtractOperator", "TextChunkActor", "BatchEmbedActor"] def test_build_graph_keeps_archetype_operator_classes() -> None: @@ -127,11 +127,11 @@ def test_build_graph_keeps_archetype_operator_classes() -> None: "PageElementDetectionActor", "OCRActor", "UDFOperator", - "_BatchEmbedActor", + "BatchEmbedActor", ] assert nodes[3].operator_class is PageElementDetectionActor assert nodes[4].operator_class is OCRActor - assert nodes[-1].operator_class is _BatchEmbedActor + assert nodes[-1].operator_class is BatchEmbedActor assert issubclass(nodes[3].operator_class, ArchetypeOperator) assert issubclass(nodes[4].operator_class, ArchetypeOperator) assert issubclass(nodes[-1].operator_class, ArchetypeOperator) @@ -162,10 +162,10 @@ def test_build_graph_resolves_endpoint_configured_nodes_to_cpu_variants() -> Non assert classes["TableStructureActor"].__name__ == "TableStructureCPUActor" assert classes["GraphicElementsActor"].__name__ == "GraphicElementsCPUActor" assert classes["OCRActor"].__name__ == "OCRCPUActor" - 
assert classes["_BatchEmbedActor"].__name__ == "_BatchEmbedCPUActor" + assert classes["BatchEmbedActor"].__name__ == "BatchEmbedCPUActor" assert issubclass(classes["PageElementDetectionActor"], CPUOperator) assert issubclass(classes["OCRActor"], CPUOperator) - assert issubclass(classes["_BatchEmbedActor"], CPUOperator) + assert issubclass(classes["BatchEmbedActor"], CPUOperator) def test_build_graph_resolves_local_nodes_to_gpu_variants_when_gpus_available() -> None: @@ -185,10 +185,10 @@ def test_build_graph_resolves_local_nodes_to_gpu_variants_when_gpus_available() assert classes["PageElementDetectionActor"] is not PageElementDetectionActor assert classes["OCRActor"] is not OCRActor - assert classes["_BatchEmbedActor"] is not _BatchEmbedActor + assert classes["BatchEmbedActor"] is not BatchEmbedActor assert issubclass(classes["PageElementDetectionActor"], GPUOperator) assert issubclass(classes["OCRActor"], GPUOperator) - assert issubclass(classes["_BatchEmbedActor"], GPUOperator) + assert issubclass(classes["BatchEmbedActor"], GPUOperator) def test_batch_tuning_to_node_overrides_auto_cpu_only_when_no_gpus() -> None: @@ -221,11 +221,11 @@ def test_batch_tuning_to_node_overrides_auto_cpu_only_when_no_gpus() -> None: cluster_resources=cluster, ) - assert overrides["_BatchEmbedActor"]["num_gpus"] == 0.0 + assert overrides["BatchEmbedActor"]["num_gpus"] == 0.0 assert overrides["OCRActor"]["num_gpus"] == 0.0 assert overrides["PageElementDetectionActor"]["num_gpus"] == 0.0 assert overrides["NemotronParseActor"]["num_gpus"] == 0.0 - assert overrides["_BatchEmbedActor"]["concurrency"] == 5 + assert overrides["BatchEmbedActor"]["concurrency"] == 5 assert overrides["OCRActor"]["concurrency"] == 4 assert overrides["PageElementDetectionActor"]["concurrency"] == 3 assert overrides["NemotronParseActor"]["concurrency"] == 2 @@ -310,7 +310,7 @@ def test_build_inprocess_graph_accepts_execution_plan() -> None: "TextChunkActor", "CaptionActor", "UDFOperator", - "_BatchEmbedActor", + 
"BatchEmbedActor", ] @@ -335,7 +335,7 @@ def test_build_inprocess_graph_supports_text_execution_plan() -> None: break node = node.children[0] - assert names == ["MultiTypeExtractOperator", "TextChunkActor", "_BatchEmbedActor"] + assert names == ["MultiTypeExtractOperator", "TextChunkActor", "BatchEmbedActor"] @pytest.mark.skipif(not is_media_available(), reason="ffmpeg not available") diff --git a/nemo_retriever/tests/test_nemotron_rerank_v2.py b/nemo_retriever/tests/test_nemotron_rerank_v2.py index 5a412bb44..179a01a32 100644 --- a/nemo_retriever/tests/test_nemotron_rerank_v2.py +++ b/nemo_retriever/tests/test_nemotron_rerank_v2.py @@ -17,6 +17,8 @@ import pytest +from nemo_retriever.tests.testing_utils import _run + # --------------------------------------------------------------------------- # Helpers to build lightweight torch / transformers stubs @@ -539,7 +541,7 @@ def test_actor_call_scores_dataframe(self): ] with patch("requests.post", return_value=mock_resp): - out = actor(df) + out = _run(actor(df)) assert "rerank_score" in out.columns assert len(out) == 2 @@ -559,7 +561,7 @@ def test_actor_call_sorts_descending_by_default(self): ] with patch("requests.post", return_value=mock_resp): - out = actor(df) + out = _run(actor(df)) scores = out["rerank_score"].tolist() assert scores == sorted(scores, reverse=True) @@ -572,7 +574,7 @@ def test_actor_call_returns_error_payload_on_exception(self): df = pd.DataFrame({"query": ["q"], "text": ["doc"]}) with patch("requests.post", side_effect=RuntimeError("connection failed")): - out = actor(df) + out = _run(actor(df)) # Should not raise; should return a DataFrame with error payload assert isinstance(out, pd.DataFrame) @@ -592,6 +594,6 @@ def test_actor_custom_score_column_name(self): mock_resp.json.return_value = {"rankings": [{"index": 0, "logit": 0.7}]} with patch("requests.post", return_value=mock_resp): - out = actor(df) + out = _run(actor(df)) assert "my_score" in out.columns diff --git 
a/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py b/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py index 7baf23ff0..087f7207a 100644 --- a/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py +++ b/nemo_retriever/tests/test_operator_flags_and_cpu_actors.py @@ -27,7 +27,7 @@ def test_gpu_operators_have_flag(self): from nemo_retriever.table.table_detection import TableStructureGPUActor from nemo_retriever.ocr.ocr import OCRGPUActor from nemo_retriever.parse.nemotron_parse import NemotronParseGPUActor - from nemo_retriever.text_embed.operators import _BatchEmbedGPUActor + from nemo_retriever.text_embed.operators import BatchEmbedGPUActor from nemo_retriever.caption.caption import CaptionGPUActor from nemo_retriever.infographic.infographic_detection import InfographicDetectionGPUActor from nemo_retriever.rerank.rerank import NemotronRerankGPUActor @@ -38,7 +38,7 @@ def test_gpu_operators_have_flag(self): assert issubclass(TableStructureGPUActor, GPUOperator) assert issubclass(OCRGPUActor, GPUOperator) assert issubclass(NemotronParseGPUActor, GPUOperator) - assert issubclass(_BatchEmbedGPUActor, GPUOperator) + assert issubclass(BatchEmbedGPUActor, GPUOperator) assert issubclass(CaptionGPUActor, GPUOperator) assert issubclass(InfographicDetectionGPUActor, GPUOperator) assert issubclass(NemotronRerankGPUActor, GPUOperator) @@ -311,47 +311,47 @@ def _make_params(self): return EmbedParams(model_name="test-model", embed_invoke_url="http://fake") def test_inherits_cpu_operator(self): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - assert issubclass(_BatchEmbedCPUActor, CPUOperator) - assert not issubclass(_BatchEmbedCPUActor, GPUOperator) + assert issubclass(BatchEmbedCPUActor, CPUOperator) + assert not issubclass(BatchEmbedCPUActor, GPUOperator) def test_uses_default_invoke_url(self): - from nemo_retriever.text_embed.cpu_operator import 
_BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor from nemo_retriever.params import EmbedParams - actor = _BatchEmbedCPUActor(params=EmbedParams(model_name="test-model")) + actor = BatchEmbedCPUActor(params=EmbedParams(model_name="test-model")) assert actor._model is None assert "integrate.api.nvidia.com" in actor._kwargs["embedding_endpoint"] def test_creates_with_custom_invoke_url(self): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - actor = _BatchEmbedCPUActor(params=self._make_params()) + actor = BatchEmbedCPUActor(params=self._make_params()) assert actor._model is None assert actor._kwargs["embedding_endpoint"] == "http://fake" @patch("nemo_retriever.text_embed.cpu_operator.embed_text_main_text_embed") def test_process(self, mock_fn): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor expected = pd.DataFrame({"text": ["hello"], "embedding": [[0.1, 0.2]]}) mock_fn.return_value = expected - actor = _BatchEmbedCPUActor(params=self._make_params()) + actor = BatchEmbedCPUActor(params=self._make_params()) result = actor.process(pd.DataFrame({"text": ["hello"]})) mock_fn.assert_called_once() pd.testing.assert_frame_equal(result, expected) def test_preprocess_passthrough(self): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - actor = _BatchEmbedCPUActor(params=self._make_params()) + actor = BatchEmbedCPUActor(params=self._make_params()) df = pd.DataFrame({"text": ["hello"]}) pd.testing.assert_frame_equal(actor.preprocess(df), df) def test_postprocess_passthrough(self): - from nemo_retriever.text_embed.cpu_operator import _BatchEmbedCPUActor + from nemo_retriever.text_embed.cpu_operator import BatchEmbedCPUActor - actor = 
_BatchEmbedCPUActor(params=self._make_params()) + actor = BatchEmbedCPUActor(params=self._make_params()) df = pd.DataFrame({"text": ["hello"]}) pd.testing.assert_frame_equal(actor.postprocess(df), df) diff --git a/nemo_retriever/tests/test_table_structure.py b/nemo_retriever/tests/test_table_structure.py index 1a55a51ea..e6c8a59ea 100644 --- a/nemo_retriever/tests/test_table_structure.py +++ b/nemo_retriever/tests/test_table_structure.py @@ -14,6 +14,7 @@ import pandas as pd import pytest +from nemo_retriever.tests.testing_utils import _run from nemo_retriever.utils.table_and_chart import join_table_structure_and_ocr_output @@ -328,7 +329,7 @@ def test_actor_error_returns_dataframe_with_error(self) -> None: df = _make_page_df() # This will fail because both models are None and no URLs set. - result = actor(df) + result = _run(actor(df)) assert "table" in result.columns assert "table_structure_ocr_v1" in result.columns meta = result.iloc[0]["table_structure_ocr_v1"] diff --git a/nemo_retriever/tests/testing_utils.py b/nemo_retriever/tests/testing_utils.py new file mode 100644 index 000000000..9989d245f --- /dev/null +++ b/nemo_retriever/tests/testing_utils.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared helpers for unit tests.""" + +from __future__ import annotations + +import asyncio + + +def _run(coro_or_result): + """Run a coroutine synchronously in tests; pass through plain values.""" + if not asyncio.iscoroutine(coro_or_result): + return coro_or_result + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro_or_result) + finally: + loop.close() diff --git a/nemo_retriever/uv.lock b/nemo_retriever/uv.lock index a1f761ba9..f06cdec02 100644 --- a/nemo_retriever/uv.lock +++ b/nemo_retriever/uv.lock @@ -2752,6 +2752,7 @@ wheels = [ name = "nemo-retriever" source = { editable = "." 
} dependencies = [ + { name = "aiohttp" }, { name = "debugpy" }, { name = "fastapi" }, { name = "httpx" }, @@ -2869,6 +2870,7 @@ stores = [ requires-dist = [ { name = "accelerate", marker = "extra == 'local'", specifier = "==1.12.0" }, { name = "addict", marker = "extra == 'local'" }, + { name = "aiohttp", specifier = ">=3.9.0" }, { name = "albumentations", marker = "extra == 'local'", specifier = "==2.0.8" }, { name = "apscheduler", marker = "extra == 'local'", specifier = ">=3.10" }, { name = "build", marker = "extra == 'dev'", specifier = ">=1.2.2" },