Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nemo_retriever/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"uvicorn[standard]>=0.30.0",
"python-multipart>=0.0.9",
# HTTP clients
"aiohttp>=3.9.0",
"httpx>=0.27.0",
"requests>=2.32.5",
"urllib3==2.6.3",
Expand Down
2 changes: 1 addition & 1 deletion nemo_retriever/src/nemo_retriever/audio/asr_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,4 @@ def apply_asr_to_df(
"""
params = ASRParams(**(asr_params or {}))
actor = ASRActor(params=params)
return actor(batch_df)
return actor.run(batch_df)
20 changes: 17 additions & 3 deletions nemo_retriever/src/nemo_retriever/chart/cpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nemo_retriever.graph.cpu_operator import CPUOperator
from nemo_retriever.nim.nim import NIMClient
from nemo_retriever.params import RemoteRetryParams
from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements
from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -143,9 +143,23 @@ def process(self, data: Any, **kwargs: Any) -> Any:
def postprocess(self, data: Any, **kwargs: Any) -> Any:
    """No-op post-processing hook; the processed batch passes through unchanged."""
    return data

def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
async def aprocess(self, data: Any, **kwargs: Any) -> Any:
    """Asynchronously run graphic-elements detection plus OCR over *data*.

    Delegates to :func:`agraphic_elements_ocr_page_elements`, forwarding the
    actor's configured models, remote endpoints, API key, timeout, retry
    policy, and inference batch size. ``kwargs`` carries any additional
    per-call options straight through to the helper.
    """
    return await agraphic_elements_ocr_page_elements(
        data,
        graphic_elements_model=self._graphic_elements_model,
        ocr_model=self._ocr_model,
        graphic_elements_invoke_url=self._graphic_elements_invoke_url,
        ocr_invoke_url=self._ocr_invoke_url,
        api_key=self._api_key,
        request_timeout_s=self._request_timeout_s,
        remote_retry=self._remote_retry,
        inference_batch_size=self._inference_batch_size,
        **kwargs,
    )

async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
try:
return self.run(batch_df, **override_kwargs)
return await self.arun(batch_df, **override_kwargs)
except BaseException as exc:
if isinstance(batch_df, pd.DataFrame):
out = batch_df.copy()
Expand Down
20 changes: 17 additions & 3 deletions nemo_retriever/src/nemo_retriever/chart/gpu_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from nemo_retriever.graph.gpu_operator import GPUOperator
from nemo_retriever.nim.nim import NIMClient
from nemo_retriever.params import RemoteRetryParams
from nemo_retriever.chart.shared import graphic_elements_ocr_page_elements
from nemo_retriever.chart.shared import agraphic_elements_ocr_page_elements, graphic_elements_ocr_page_elements


class GraphicElementsActor(AbstractOperator, GPUOperator):
Expand Down Expand Up @@ -88,9 +88,23 @@ def process(self, data: Any, **kwargs: Any) -> Any:
def postprocess(self, data: Any, **kwargs: Any) -> Any:
    """Identity post-processing step: hand the batch back untouched."""
    return data

def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
async def aprocess(self, data: Any, **kwargs: Any) -> Any:
    """Asynchronously run graphic-elements detection plus OCR over *data*.

    Delegates to :func:`agraphic_elements_ocr_page_elements`, forwarding the
    actor's configured models, remote endpoints, API key, timeout, retry
    policy, and inference batch size. ``kwargs`` carries any additional
    per-call options straight through to the helper.
    """
    return await agraphic_elements_ocr_page_elements(
        data,
        graphic_elements_model=self._graphic_elements_model,
        ocr_model=self._ocr_model,
        graphic_elements_invoke_url=self._graphic_elements_invoke_url,
        ocr_invoke_url=self._ocr_invoke_url,
        api_key=self._api_key,
        request_timeout_s=self._request_timeout_s,
        remote_retry=self._remote_retry,
        inference_batch_size=self._inference_batch_size,
        **kwargs,
    )

async def __call__(self, batch_df: Any, **override_kwargs: Any) -> Any:
try:
return self.run(batch_df, **override_kwargs)
return await self.arun(batch_df, **override_kwargs)
except BaseException as exc:
if isinstance(batch_df, pd.DataFrame):
out = batch_df.copy()
Expand Down
231 changes: 231 additions & 0 deletions nemo_retriever/src/nemo_retriever/chart/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@

import base64
import io
import logging
import time
import traceback

logger = logging.getLogger(__name__)

import pandas as pd
from nemo_retriever.nim.nim import NIMClient, invoke_image_inference_batches
from nemo_retriever.params import RemoteRetryParams
Expand Down Expand Up @@ -576,6 +579,234 @@ def graphic_elements_ocr_page_elements(
return out


async def agraphic_elements_ocr_page_elements(
    batch_df: Any,
    *,
    graphic_elements_model: Any = None,
    ocr_model: Any = None,
    graphic_elements_invoke_url: str = "",
    ocr_invoke_url: str = "",
    api_key: str = "",
    request_timeout_s: float = 120.0,
    remote_retry: RemoteRetryParams | None = None,
    **kwargs: Any,
) -> Any:
    """Async version of :func:`graphic_elements_ocr_page_elements`.

    For each row of ``batch_df``, crops the ``"chart"`` detections out of the
    row's base64 page image, runs graphic-elements detection and OCR on every
    crop — remotely when the corresponding invoke URL is set, otherwise with
    the supplied local model (executed via ``asyncio.to_thread`` so the event
    loop is not blocked) — and joins the two outputs into per-chart text.

    Args:
        batch_df: A ``pandas.DataFrame`` with ``page_elements_v3`` and
            ``page_image`` columns (rows missing a page image are skipped).
        graphic_elements_model: Local GE model; required when
            ``graphic_elements_invoke_url`` is empty.
        ocr_model: Local OCR model; required when ``ocr_invoke_url`` is empty.
        graphic_elements_invoke_url / ocr_invoke_url: Remote NIM endpoints;
            a non-empty URL selects remote inference for that stage.
        api_key: Optional API key for the remote endpoints.
        request_timeout_s: Per-request timeout for remote calls.
        remote_retry: Retry/concurrency settings; defaults are built from
            ``kwargs`` when omitted.
        **kwargs: Extra options, e.g. ``inference_batch_size`` (default 8).

    Returns:
        A copy of ``batch_df`` with two added columns:
        ``chart`` (per row, a list of ``{"bbox_xyxy_norm", "text"}`` dicts)
        and ``graphic_elements_ocr_v1`` (per row, ``{"timing", "error"}``;
        note the same total batch elapsed time is recorded for every row).

    Raises:
        NotImplementedError: If ``batch_df`` is not a ``pandas.DataFrame``.
        ValueError: If a local model is missing while its invoke URL is unset.
    """
    import asyncio

    from nemo_retriever.nim.nim import ainvoke_image_inference_batches
    from nemo_retriever.ocr.ocr import (
        _blocks_to_text,
        _crop_all_from_page,
        _extract_remote_ocr_item,
        _np_rgb_to_b64_png,
        _parse_ocr_result,
    )
    from nemo_retriever.utils.table_and_chart import join_graphic_elements_and_ocr_output

    retry = remote_retry or RemoteRetryParams(
        remote_max_pool_workers=int(kwargs.get("remote_max_pool_workers", 16)),
        remote_max_retries=int(kwargs.get("remote_max_retries", 10)),
        remote_max_429_retries=int(kwargs.get("remote_max_429_retries", 5)),
    )

    if not isinstance(batch_df, pd.DataFrame):
        raise NotImplementedError("agraphic_elements_ocr_page_elements currently only supports pandas.DataFrame input.")

    ge_url = (graphic_elements_invoke_url or kwargs.get("graphic_elements_invoke_url") or "").strip()
    ocr_url = (ocr_invoke_url or kwargs.get("ocr_invoke_url") or "").strip()
    use_remote_ge = bool(ge_url)
    use_remote_ocr = bool(ocr_url)

    if not use_remote_ge and graphic_elements_model is None:
        raise ValueError("A local `graphic_elements_model` is required when `graphic_elements_invoke_url` is not set.")
    if not use_remote_ocr and ocr_model is None:
        raise ValueError("A local `ocr_model` is required when `ocr_invoke_url` is not set.")

    label_names = _labels_from_model(graphic_elements_model) if graphic_elements_model is not None else []
    inference_batch_size = int(kwargs.get("inference_batch_size", 8))

    async def _invoke_remote(url: str, image_b64s: List[str]) -> List[Any]:
        # Single place for the remote endpoint kwargs — the GE and OCR call
        # sites previously duplicated this block four times.
        return await ainvoke_image_inference_batches(
            invoke_url=url,
            image_b64_list=image_b64s,
            api_key=api_key or None,
            timeout_s=float(request_timeout_s),
            max_batch_size=inference_batch_size,
            max_concurrency=int(retry.remote_max_pool_workers),
            max_retries=int(retry.remote_max_retries),
            max_429_retries=int(retry.remote_max_429_retries),
        )

    def _expect(items: List[Any], n: int, what: str) -> None:
        # Remote contract: exactly one response per submitted crop.
        if len(items) != n:
            raise RuntimeError(f"Expected {n} {what} responses, got {len(items)}")

    def _ge_from_remote(resp: Any) -> List[Dict[str, Any]]:
        # Keep only detections at or above the minimum graphic-elements score.
        return [
            d
            for d in _remote_response_to_ge_detections(resp)
            if (d.get("score") or 0.0) >= YOLOX_GRAPHIC_MIN_SCORE
        ]

    all_chart: List[List[Dict[str, Any]]] = []
    all_meta: List[Dict[str, Any]] = []

    t0_total = time.perf_counter()

    for row in batch_df.itertuples(index=False):
        chart_items: List[Dict[str, Any]] = []
        row_error: Any = None

        try:
            pe = getattr(row, "page_elements_v3", None)
            dets: List[Dict[str, Any]] = []
            if isinstance(pe, dict):
                dets = pe.get("detections") or []
                if not isinstance(dets, list):
                    dets = []

            page_image = getattr(row, "page_image", None) or {}
            page_image_b64 = page_image.get("image_b64") if isinstance(page_image, dict) else None

            # Rows without a usable page image contribute an empty chart list.
            if not isinstance(page_image_b64, str) or not page_image_b64:
                all_chart.append(chart_items)
                all_meta.append({"timing": None, "error": None})
                continue

            crops = _crop_all_from_page(page_image_b64, dets, {"chart"})

            if not crops:
                all_chart.append(chart_items)
                all_meta.append({"timing": None, "error": None})
                continue

            # Base64 crops are only needed for remote inference.
            crop_b64s = (
                [_np_rgb_to_b64_png(crop_array) for _, _, crop_array in crops]
                if (use_remote_ge or use_remote_ocr)
                else []
            )

            ge_results: List[List[Dict[str, Any]]] = []
            ocr_results: List[Any] = []

            if use_remote_ge and use_remote_ocr:
                # Both stages are remote: fan out to the two services concurrently.
                ge_items, ocr_items = await asyncio.gather(
                    _invoke_remote(ge_url, crop_b64s),
                    _invoke_remote(ocr_url, crop_b64s),
                )
                _expect(ge_items, len(crops), "GE")
                ge_results = [_ge_from_remote(resp) for resp in ge_items]
                _expect(ocr_items, len(crops), "OCR")
                ocr_results = [_extract_remote_ocr_item(resp) for resp in ocr_items]
            else:
                if use_remote_ge:
                    ge_items = await _invoke_remote(ge_url, crop_b64s)
                    _expect(ge_items, len(crops), "GE")
                    ge_results = [_ge_from_remote(resp) for resp in ge_items]
                else:

                    def _run_local_ge():
                        # Blocking local inference; runs off the event loop
                        # via asyncio.to_thread below.
                        results = []
                        for _, _, crop_array in crops:
                            chw = torch.from_numpy(crop_array).permute(2, 0, 1).contiguous().to(dtype=torch.float32)
                            h, w = crop_array.shape[:2]
                            x = chw.unsqueeze(0)
                            try:
                                pre = graphic_elements_model.preprocess(x)
                            except Exception:
                                # Models without a preprocess step fall back to the raw tensor.
                                pre = x
                            if isinstance(pre, torch.Tensor) and pre.ndim == 3:
                                pre = pre.unsqueeze(0)
                            pred = graphic_elements_model.invoke(pre, (h, w))
                            ge_dets = _prediction_to_detections(pred, label_names=label_names)
                            results.append([d for d in ge_dets if (d.get("score") or 0.0) >= YOLOX_GRAPHIC_MIN_SCORE])
                        return results

                    ge_results = await asyncio.to_thread(_run_local_ge)

                if use_remote_ocr:
                    ocr_items = await _invoke_remote(ocr_url, crop_b64s)
                    _expect(ocr_items, len(crops), "OCR")
                    ocr_results = [_extract_remote_ocr_item(resp) for resp in ocr_items]
                else:

                    def _run_local_ocr():
                        # Blocking local OCR; runs off the event loop via asyncio.to_thread.
                        results = []
                        for _, _, crop_array in crops:
                            results.append(ocr_model.invoke(crop_array, merge_level="word"))
                        return results

                    ocr_results = await asyncio.to_thread(_run_local_ocr)

            for crop_i, (label_name, bbox, crop_array) in enumerate(crops):
                crop_hw = (int(crop_array.shape[0]), int(crop_array.shape[1]))
                ge_dets = ge_results[crop_i]
                ocr_preds = ocr_results[crop_i]

                text = join_graphic_elements_and_ocr_output(ge_dets, ocr_preds, crop_hw)

                # Fall back to plain OCR text when the join produced nothing.
                if not text:
                    blocks = _parse_ocr_result(ocr_preds)
                    text = _blocks_to_text(blocks)

                chart_items.append({"bbox_xyxy_norm": bbox, "text": text})

        except BaseException as e:
            # Per-row isolation: record the failure in the row's metadata and
            # keep processing the rest of the batch.
            logger.warning("graphic-elements+OCR failed: %s: %s", type(e).__name__, e, exc_info=True)
            row_error = {
                "stage": "graphic_elements_ocr_page_elements",
                "type": e.__class__.__name__,
                "message": str(e),
                "traceback": "".join(traceback.format_exception(type(e), e, e.__traceback__)),
            }

        all_chart.append(chart_items)
        all_meta.append({"timing": None, "error": row_error})

    # NOTE: the same total batch wall-clock time is stamped on every row.
    elapsed = time.perf_counter() - t0_total
    for meta in all_meta:
        meta["timing"] = {"seconds": float(elapsed)}

    out = batch_df.copy()
    out["chart"] = all_chart
    out["graphic_elements_ocr_v1"] = all_meta
    return out


# ---------------------------------------------------------------------------
# Combined graphic-elements + OCR Ray Actor
# ---------------------------------------------------------------------------
Loading
Loading