diff --git a/nemo_retriever/harness/test_configs.yaml b/nemo_retriever/harness/test_configs.yaml index 4d90e05fe..f98b1af3e 100644 --- a/nemo_retriever/harness/test_configs.yaml +++ b/nemo_retriever/harness/test_configs.yaml @@ -200,7 +200,7 @@ presets: datasets: bo20: - path: /home/jdyer/datasets/bo20 + path: /datasets/nv-ingest/bo20 query_csv: null input_type: pdf recall_required: false diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index 3c55c6745..4de3d2c83 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -53,11 +53,11 @@ from nemo_retriever.params import ExtractParams from nemo_retriever.params import StoreParams from nemo_retriever.params import TextChunkParams +from nemo_retriever.params import VdbUploadParams from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL -from nemo_retriever.params.models import BatchTuningParams +from nemo_retriever.params.models import BatchTuningParams, LanceDbParams from nemo_retriever.utils.input_files import resolve_input_patterns from nemo_retriever.utils.remote_auth import resolve_remote_api_key -from nemo_retriever.vector_store.lancedb_store import handle_lancedb logger = logging.getLogger(__name__) app = typer.Typer() @@ -126,46 +126,77 @@ def _configure_logging(log_file: Optional[Path], *, debug: bool = False) -> tupl return fh, original_stdout, original_stderr -def _ensure_lancedb_table(uri: str, table_name: str) -> None: - from nemo_retriever.vector_store.lancedb_utils import lancedb_schema - import lancedb - import pyarrow as pa - - Path(uri).mkdir(parents=True, exist_ok=True) - db = lancedb.connect(uri) - try: - db.open_table(table_name) - return - except Exception: - pass - schema = lancedb_schema() - empty = pa.table({f.name: [] for f in schema}, schema=schema) - db.create_table(table_name, data=empty, schema=schema, 
mode="create") - - def _write_runtime_summary( runtime_metrics_dir: Optional[Path], runtime_metrics_prefix: Optional[str], payload: dict[str, object], + metrics_output_file: Optional[Path] = None, ) -> None: - if runtime_metrics_dir is None and not runtime_metrics_prefix: - return + if runtime_metrics_dir is not None or runtime_metrics_prefix: + target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve() + target_dir.mkdir(parents=True, exist_ok=True) + prefix = (runtime_metrics_prefix or "run").strip() or "run" + target = target_dir / f"{prefix}.runtime.summary.json" + target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + + if metrics_output_file is not None: + out_path = Path(metrics_output_file).expanduser() + out_path.parent.mkdir(parents=True, exist_ok=True) + out_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") + - target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve() - target_dir.mkdir(parents=True, exist_ok=True) - prefix = (runtime_metrics_prefix or "run").strip() or "run" - target = target_dir / f"{prefix}.runtime.summary.json" - target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8") +def _page_keys_from_df(result_df) -> list[str]: + """Return stable page/input-unit keys from a pipeline result DataFrame.""" + if result_df is None or result_df.empty: + return [] + source_column = next((c for c in ("source_id", "path", "source_path") if c in result_df.columns), None) + if source_column is not None and "page_number" in result_df.columns: + key_df = result_df[[source_column, "page_number"]].dropna() + return [f"{source}\x1f{page}" for source, page in key_df.itertuples(index=False, name=None)] -def _count_input_units(result_df) -> int: - if "source_id" in result_df.columns: - return int(result_df["source_id"].nunique()) - if "source_path" in result_df.columns: - return int(result_df["source_path"].nunique()) + if 
source_column is not None: + return [str(v) for v in result_df[source_column].dropna().tolist()] + + if "page_number" in result_df.columns: + return [str(v) for v in result_df["page_number"].dropna().tolist()] + + return [] + + +def _count_processed_pages_from_df(result_df) -> int: + keys = _page_keys_from_df(result_df) + if keys: + return int(len(set(keys))) return int(len(result_df.index)) +def _extract_page_key_batch(batch): + import pandas as pd + + keys = _page_keys_from_df(batch) + return pd.DataFrame({"_page_key": keys}) + + +def _count_processed_pages_from_dataset(dataset, *, fallback_rows: int) -> int: + try: + columns = set(dataset.columns()) + except Exception: + return int(fallback_rows) + + if not columns.intersection({"source_id", "path", "source_path", "page_number"}): + return int(fallback_rows) + + try: + key_ds = dataset.map_batches(_extract_page_key_batch, batch_format="pandas") + if int(key_ds.count()) == 0: + return 0 + return int(key_ds.groupby("_page_key").count().count()) + except Exception: + logger.warning("Could not estimate processed pages from Ray Dataset; falling back to output row count.", exc_info=True) + return int(fallback_rows) + + def _resolve_file_patterns(input_path: Path, input_type: str) -> list[str]: import glob as _glob @@ -310,6 +341,13 @@ def main( runtime_metrics_dir: Optional[Path] = typer.Option(None, "--runtime-metrics-dir", path_type=Path), runtime_metrics_prefix: Optional[str] = typer.Option(None, "--runtime-metrics-prefix"), detection_summary_file: Optional[Path] = typer.Option(None, "--detection-summary-file", path_type=Path), + metrics_output_file: Optional[Path] = typer.Option( + None, + "--metrics-output-file", + path_type=Path, + dir_okay=False, + help="JSON file path to write structured run metrics (used by the harness).", + ), log_file: Optional[Path] = typer.Option(None, "--log-file", path_type=Path, dir_okay=False), ) -> None: _ = ctx @@ -328,7 +366,6 @@ def main( os.environ["RAY_LOG_TO_DRIVER"] = "1" if 
ray_log_to_driver else "0" lancedb_uri = str(Path(lancedb_uri).expanduser().resolve()) - _ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE) remote_api_key = resolve_remote_api_key(api_key) extract_remote_api_key = remote_api_key @@ -535,6 +572,21 @@ def main( ingestor = ingestor.embed(embed_params) + # VDB upload runs inside the graph — rows stream to the configured + # backend as they are produced, so we never need to collect the entire + # result set on the driver just for the write. Index creation happens + # automatically in GraphIngestor._finalize_vdb() after the pipeline. + ingestor = ingestor.vdb_upload( + VdbUploadParams( + lancedb=LanceDbParams( + lancedb_uri=lancedb_uri, + table_name=LANCEDB_TABLE, + hybrid=hybrid, + overwrite=True, + ), + ) + ) + # ------------------------------------------------------------------ # Execute the graph via the executor # ------------------------------------------------------------------ @@ -542,7 +594,7 @@ def main( ingest_start = time.perf_counter() # GraphIngestor.ingest() builds the Graph, creates the executor, - # and calls executor.ingest(file_patterns) returning: + # calls executor.ingest(file_patterns), and finalizes the VDB index. # batch mode -> materialized ray.data.Dataset # inprocess mode -> pandas.DataFrame result = ingestor.ingest() @@ -550,26 +602,36 @@ def main( ingestion_only_total_time = time.perf_counter() - ingest_start # ------------------------------------------------------------------ - # Collect results + # Collect results only when downstream features need the full DataFrame. + # Page/row metrics stay separate: PPS is pages/sec, while num_rows tracks + # output rows after any content explosion. 
# ------------------------------------------------------------------ if run_mode == "batch": import ray - ray_download_start = time.perf_counter() - ingest_local_results = result.take_all() - ray_download_time = time.perf_counter() - ray_download_start - - import pandas as pd - - result_df = pd.DataFrame(ingest_local_results) - num_rows = _count_input_units(result_df) + needs_result_df = detection_summary_file is not None or save_intermediate is not None + if needs_result_df: + ray_download_start = time.perf_counter() + ingest_local_results = result.take_all() + ray_download_time = time.perf_counter() - ray_download_start + + import pandas as pd + + result_df = pd.DataFrame(ingest_local_results) + processed_pages = _count_processed_pages_from_df(result_df) + output_rows = int(len(result_df.index)) + else: + ray_download_time = 0.0 + result_df = None + output_rows = int(result.count()) + processed_pages = _count_processed_pages_from_dataset(result, fallback_rows=output_rows) else: import pandas as pd result_df = result - ingest_local_results = result_df.to_dict("records") ray_download_time = 0.0 - num_rows = _count_input_units(result_df) + processed_pages = _count_processed_pages_from_df(result_df) + output_rows = int(len(result_df.index)) if save_intermediate is not None: out_dir = Path(save_intermediate).expanduser().resolve() @@ -589,35 +651,22 @@ def main( collect_detection_summary_from_df(result_df), ) - # ------------------------------------------------------------------ - # Write to LanceDB - # ------------------------------------------------------------------ - lancedb_write_start = time.perf_counter() - handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") - lancedb_write_time = time.perf_counter() - lancedb_write_start - # ------------------------------------------------------------------ # Recall / BEIR evaluation # ------------------------------------------------------------------ - import lancedb as 
_lancedb_mod - - db = _lancedb_mod.connect(lancedb_uri) - table = db.open_table(LANCEDB_TABLE) - - if int(table.count_rows()) == 0: - logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode) + def _empty_summary(reason_label: str) -> None: _write_runtime_summary( runtime_metrics_dir, runtime_metrics_prefix, { "run_mode": run_mode, "input_path": str(Path(input_path).resolve()), - "input_pages": int(num_rows), - "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), + "input_pages": int(processed_pages), + "num_pages": int(processed_pages), + "num_rows": int(output_rows), "ingestion_only_secs": float(ingestion_only_total_time), "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), + "lancedb_write_secs": 0.0, "evaluation_secs": 0.0, "total_secs": float(time.perf_counter() - ingest_start), "evaluation_mode": evaluation_mode, @@ -625,8 +674,26 @@ def main( "recall_details": bool(recall_details), "lancedb_uri": str(lancedb_uri), "lancedb_table": str(LANCEDB_TABLE), + "skip_reason": reason_label, }, + metrics_output_file=metrics_output_file, ) + + import lancedb as _lancedb_mod + + db = _lancedb_mod.connect(lancedb_uri) + try: + table = db.open_table(LANCEDB_TABLE) + except Exception: + logger.warning("LanceDB table %r was not created; skipping %s evaluation.", LANCEDB_TABLE, evaluation_mode) + _empty_summary("lancedb_table_missing") + if run_mode == "batch": + ray.shutdown() + return + + if int(table.count_rows()) == 0: + logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode) + _empty_summary("lancedb_table_empty") if run_mode == "batch": ray.shutdown() return @@ -675,27 +742,7 @@ def main( query_csv_path = Path(query_csv) if not query_csv_path.exists(): logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv_path) - _write_runtime_summary( - runtime_metrics_dir, - runtime_metrics_prefix, - { - "run_mode": run_mode, - "input_path": 
str(Path(input_path).resolve()), - "input_pages": int(num_rows), - "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), - "ingestion_only_secs": float(ingestion_only_total_time), - "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), - "evaluation_secs": 0.0, - "total_secs": float(time.perf_counter() - ingest_start), - "evaluation_mode": evaluation_mode, - "evaluation_metrics": {}, - "recall_details": bool(recall_details), - "lancedb_uri": str(lancedb_uri), - "lancedb_table": str(LANCEDB_TABLE), - }, - ) + _empty_summary("query_csv_missing") if run_mode == "batch": ray.shutdown() return @@ -733,12 +780,12 @@ def main( { "run_mode": run_mode, "input_path": str(Path(input_path).resolve()), - "input_pages": int(num_rows), - "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), + "input_pages": int(processed_pages), + "num_pages": int(processed_pages), + "num_rows": int(output_rows), "ingestion_only_secs": float(ingestion_only_total_time), "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), + "lancedb_write_secs": 0.0, "evaluation_secs": float(evaluation_total_time), "total_secs": float(total_time), "evaluation_mode": evaluation_mode, @@ -748,13 +795,14 @@ def main( "lancedb_uri": str(lancedb_uri), "lancedb_table": str(LANCEDB_TABLE), }, + metrics_output_file=metrics_output_file, ) if run_mode == "batch": ray.shutdown() print_run_summary( - num_rows, + processed_pages, Path(input_path), hybrid, lancedb_uri, @@ -762,7 +810,7 @@ def main( total_time, ingestion_only_total_time, ray_download_time, - lancedb_write_time, + 0.0, evaluation_total_time, evaluation_metrics, evaluation_label=evaluation_label, diff --git a/nemo_retriever/src/nemo_retriever/graph/__init__.py b/nemo_retriever/src/nemo_retriever/graph/__init__.py index cd7fcf8a5..6966e35cd 100644 --- a/nemo_retriever/src/nemo_retriever/graph/__init__.py +++ 
b/nemo_retriever/src/nemo_retriever/graph/__init__.py @@ -16,6 +16,7 @@ from nemo_retriever.graph.graph_pipeline_registry import GraphPipelineRegistry, default_registry from nemo_retriever.graph.pipeline_graph import Graph, Node from nemo_retriever.graph.store_operator import StoreOperator +from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator from nemo_retriever.graph.webhook_operator import WebhookNotifyOperator __all__ = [ @@ -33,6 +34,7 @@ "RayDataExecutor", "StoreOperator", "UDFOperator", + "VDBUploadOperator", "WebhookNotifyOperator", "default_registry", ] diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py index 95acdf0b6..0a27065cb 100644 --- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py +++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py @@ -15,7 +15,7 @@ from nemo_retriever.audio import MediaChunkActor from nemo_retriever.chart.chart_detection import GraphicElementsActor from nemo_retriever.dedup.dedup import dedup_images -from nemo_retriever.graph import Graph, StoreOperator, UDFOperator, WebhookNotifyOperator +from nemo_retriever.graph import Graph, StoreOperator, UDFOperator, VDBUploadOperator, WebhookNotifyOperator from nemo_retriever.graph.content_transforms import ( _CONTENT_COLUMNS, collapse_content_to_page_rows, @@ -344,7 +344,8 @@ def _resolve_execution_inputs( caption_params: Any | None, store_params: Any | None, embed_params: Any | None, - webhook_params: Any | None = None, + vdb_upload_params: Any | None, + webhook_params: Any | None, stage_order: tuple[str, ...], ) -> tuple[ str, @@ -359,6 +360,7 @@ def _resolve_execution_inputs( Any | None, Any | None, Any | None, + Any | None, tuple[str, ...], ]: """Resolve legacy builder args or a shared execution plan into one input tuple.""" @@ -376,11 +378,13 @@ def _resolve_execution_inputs( caption_params, store_params, embed_params, + vdb_upload_params, webhook_params, 
stage_order, ) stage_map = {stage.name: stage.params for stage in execution_plan.stages} + sink_map = {sink.name: sink.params for sink in execution_plan.sinks} return ( execution_plan.extraction_mode, execution_plan.extract_params, @@ -393,8 +397,9 @@ def _resolve_execution_inputs( stage_map.get("caption"), stage_map.get("store"), stage_map.get("embed"), + sink_map.get("vdb_upload"), stage_map.get("webhook"), - tuple(stage.name for stage in execution_plan.stages), + tuple(stage.name for stage in execution_plan.stages) + tuple(sink.name for sink in execution_plan.sinks), ) @@ -420,17 +425,21 @@ def _append_ordered_transform_stages( caption_params: Any | None, store_params: Any | None, embed_params: Any | None, + vdb_upload_params: Any | None, + vdb_op: Any | None, webhook_params: Any | None = None, stage_order: tuple[str, ...], supports_dedup: bool, reshape_for_modal_content: bool, + vdb_upload_ops_out: list | None = None, ) -> Graph: """Append post-extraction transform stages in the exact recorded plan order.""" pending_stages = [ stage for stage in stage_order - if stage in {"dedup", "split", "caption", "store", "embed"} and (supports_dedup or stage != "dedup") + if stage in {"dedup", "split", "caption", "store", "embed", "vdb_upload"} + and (supports_dedup or stage != "dedup") ] if not pending_stages: if supports_dedup and dedup_params is not None: @@ -443,6 +452,8 @@ def _append_ordered_transform_stages( pending_stages.append("split") if embed_params is not None: pending_stages.append("embed") + if vdb_upload_params is not None: + pending_stages.append("vdb_upload") for stage_name in pending_stages: if stage_name == "store" and store_params is not None: @@ -480,6 +491,11 @@ def _append_ordered_transform_stages( name="ExplodeContentToRows", ) graph = graph >> _BatchEmbedActor(params=embed_params) + elif stage_name == "vdb_upload" and vdb_upload_params is not None: + vdb_upload_op = VDBUploadOperator(params=vdb_upload_params, vdb_op=vdb_op) + if vdb_upload_ops_out 
is not None: + vdb_upload_ops_out.append(vdb_upload_op) + graph = graph >> vdb_upload_op if webhook_params is not None and getattr(webhook_params, "endpoint_url", None): graph = graph >> WebhookNotifyOperator(params=webhook_params) @@ -501,8 +517,11 @@ def build_graph( split_params: Any | None = None, caption_params: Any | None = None, store_params: Any | None = None, + vdb_upload_params: Any | None = None, + vdb_op: Any | None = None, webhook_params: Any | None = None, stage_order: tuple[str, ...] = (), + vdb_upload_ops_out: list | None = None, ) -> Graph: """Build a batch graph from explicit params or a shared execution plan.""" @@ -518,6 +537,7 @@ def build_graph( caption_params, store_params, embed_params, + vdb_upload_params, webhook_params, stage_order, ) = _resolve_execution_inputs( @@ -533,6 +553,7 @@ def build_graph( caption_params=caption_params, store_params=store_params, embed_params=embed_params, + vdb_upload_params=vdb_upload_params, webhook_params=webhook_params, stage_order=stage_order, ) @@ -670,10 +691,13 @@ def build_graph( caption_params=caption_params, store_params=store_params, embed_params=embed_params, + vdb_upload_params=vdb_upload_params, + vdb_op=vdb_op, webhook_params=webhook_params, stage_order=stage_order, supports_dedup=True, reshape_for_modal_content=True, + vdb_upload_ops_out=vdb_upload_ops_out, ) diff --git a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py new file mode 100644 index 000000000..979f44f32 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py @@ -0,0 +1,379 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Graph operator for streaming VDB uploads during pipeline execution. 
+ +Wraps existing client VDB classes (``nv_ingest_client.util.vdb``) so that +any backend implementing the client :class:`VDB` ABC can be used as a +pipeline sink. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Sequence + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator + +logger = logging.getLogger(__name__) + + +def _canonical_to_nvingest(rows: Sequence[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + """Convert canonical VDB records to NV-Ingest pipeline format. + + The client Milvus implementation expects ``list[list[dict]]`` where each + inner dict has ``document_type`` and a nested ``metadata`` dict with keys + ``embedding``, ``content``, ``content_metadata``, ``source_metadata``. + """ + elements: List[Dict[str, Any]] = [] + for row in rows: + meta_str = row.get("metadata", "{}") + source_str = row.get("source", "{}") + try: + content_metadata = json.loads(meta_str) if isinstance(meta_str, str) else meta_str + except (json.JSONDecodeError, TypeError): + content_metadata = {} + try: + source_metadata = json.loads(source_str) if isinstance(source_str, str) else source_str + except (json.JSONDecodeError, TypeError): + source_metadata = {} + + elements.append( + { + "document_type": "text", + "metadata": { + "content": row.get("text", ""), + "embedding": row.get("vector"), + "content_metadata": content_metadata, + "source_metadata": source_metadata, + }, + } + ) + return [elements] if elements else [] + + +class VDBUploadOperator(AbstractOperator, CPUOperator): + """Write pipeline embeddings to a vector store as data flows through the graph. + + Wraps a client VDB instance (LanceDB, Milvus, etc.) from + ``nv_ingest_client.util.vdb``. ``preprocess`` extracts canonical + records from the DataFrame (backend-agnostic); ``process`` converts + to the target format and writes (backend-specific). 
+ + The DataFrame is passed through unchanged — this is a side-effect + operator. + + **Configuring the backend.** Callers can either: + + * pass ``params`` and let the operator construct the client VDB from + ``VdbUploadParams.backend`` + ``client_vdb_kwargs``, or + * pass a pre-constructed client VDB via ``vdb_op`` + (e.g. ``Milvus(...)`` / ``LanceDB(...)`` from + ``nv_ingest_client.util.vdb``). ``params`` still supplies the + record-shaping config (embedding column, text column, etc.). + + LanceDB is the default backend. Custom ``vdb_op`` implementations should + follow the existing ``nv_ingest_client.util.vdb.adt_vdb.VDB`` contract: + ``create_index()``, ``write_to_index(records)``, ``retrieval()``, and + ``run(records)``. Milvus is handled specially inside this wrapper only to + avoid per-batch index waits; callers still pass the standard Milvus VDB + instance via ``vdb_op=Milvus(...)``. + + **Concurrency**: This operator must run with ``concurrency=1`` in batch + mode. The single actor creates the backend destination on its first + write (respecting backend-specific overwrite/recreate settings) and + appends on subsequent writes. Backend work that must run exactly once + after all batches is handled by ``finalize()`` on the driver. + """ + + def __init__( + self, + *, + params: Any = None, + vdb_op: Any = None, + ) -> None: + super().__init__() + from nemo_retriever.params.models import LanceDbParams, VdbUploadParams + + # Store as self. so get_constructor_kwargs() captures both for + # deferred reconstruction on Ray workers. 
+ self.params = params + self.vdb_op = vdb_op + + if isinstance(params, VdbUploadParams): + self._vdb_params = params + self._lance_params = params.lancedb + elif isinstance(params, LanceDbParams): + self._vdb_params = None + self._lance_params = params + else: + self._vdb_params = VdbUploadParams() + self._lance_params = self._vdb_params.lancedb + + if vdb_op is not None: + self._backend_name = type(vdb_op).__name__.lower() + else: + self._backend_name = getattr(self._vdb_params, "backend", "lancedb") if self._vdb_params else "lancedb" + + self._client_vdb: Any = vdb_op + self._table: Any = None + self._milvus_client: Any = None + self._index_created: bool = False + self._pending_records: List[Dict[str, Any]] = [] + + # ------------------------------------------------------------------ + # Client VDB construction + # ------------------------------------------------------------------ + + def _create_client_vdb(self) -> Any: + """Lazily construct the client VDB instance.""" + from nv_ingest_client.util.vdb import get_vdb_op_cls + + if self._backend_name == "lancedb": + from nemo_retriever.vector_store.lancedb_utils import build_client_lancedb + + return build_client_lancedb(self._lance_params) + + kwargs = getattr(self._vdb_params, "client_vdb_kwargs", {}) or {} + return get_vdb_op_cls(self._backend_name)(**kwargs) + + # ------------------------------------------------------------------ + # Operator lifecycle + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> Any: + """Extract canonical VDB records from the DataFrame (backend-agnostic).""" + import pandas as pd + + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + if not isinstance(data, pd.DataFrame) or data.empty: + self._pending_records = [] + return data + + self._pending_records = build_vdb_records( + data, + embedding_column=self._lance_params.embedding_column, + embedding_key=self._lance_params.embedding_key, + 
include_text=self._lance_params.include_text, + text_column=self._lance_params.text_column, + ) + return data + + def process(self, data: Any, **kwargs: Any) -> Any: + """Write pending records to the backend (backend-specific).""" + if not self._pending_records: + return data + + if self._client_vdb is None: + self._client_vdb = self._create_client_vdb() + + if self._backend_name == "lancedb": + self._write_lancedb_batch(self._pending_records) + else: + self._write_via_client(self._pending_records) + + logger.debug("VDBUploadOperator: wrote %d records", len(self._pending_records)) + self._pending_records = [] + return data + + def postprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + def finalize(self) -> None: + """Driver-side post-pipeline hook. + + ``AbstractOperator.run()`` fires the preprocess/process/postprocess + lifecycle once per batch, so anything that must run exactly once at + the end of ingestion (flush, wait-for-index, etc.) is driven from + here by the driver after ``executor.ingest()`` returns. + """ + if self._client_vdb is None and self._backend_name == "milvus": + self._client_vdb = self._create_client_vdb() + if self._client_vdb is None: + return + if self._is_milvus_vdb(self._client_vdb): + self._finalize_milvus() + + # ------------------------------------------------------------------ + # LanceDB streaming write path + # ------------------------------------------------------------------ + + def _write_lancedb_batch(self, records: List[Dict[str, Any]]) -> None: + """Stream records to LanceDB via table.add(). + + The client LanceDB class is used for config (uri, table_name, + overwrite) and post-pipeline index creation, but its + ``create_index()`` writes all records at once with no append + support. For streaming writes we call the lancedb library + directly. 
+ """ + if self._table is None: + import lancedb + + from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema + + dim = infer_vector_dim(records) + schema = lancedb_schema(vector_dim=dim) + mode = "overwrite" if self._client_vdb.overwrite else "create" + db = lancedb.connect(uri=self._client_vdb.uri) + self._table = db.create_table( + self._client_vdb.table_name, + schema=schema, + mode=mode, + ) + + self._table.add(records) + + # ------------------------------------------------------------------ + # Non-LanceDB write path (Milvus, custom VDBs, etc.) + # ------------------------------------------------------------------ + + def _write_via_client(self, records: List[Dict[str, Any]]) -> None: + """Convert canonical records and delegate to the configured client VDB.""" + nvingest_records = _canonical_to_nvingest(records) + if not nvingest_records: + return + + if self._is_milvus_vdb(self._client_vdb): + self._write_milvus_streaming(nvingest_records) + return + + if not self._index_created: + self._client_vdb.create_index() + self._index_created = True + + self._client_vdb.write_to_index(nvingest_records) + + # ------------------------------------------------------------------ + # Milvus streaming write path + # ------------------------------------------------------------------ + + @staticmethod + def _milvus_import_error() -> ImportError: + return ImportError( + "Milvus VDB upload requires pymilvus. Install it with " + "`pip install 'nv-ingest-client[milvus]'`." 
+ ) + + @staticmethod + def _is_milvus_vdb(vdb_op: Any) -> bool: + try: + from nv_ingest_client.util.vdb.milvus import Milvus + except ImportError: + return False + return isinstance(vdb_op, Milvus) + + @staticmethod + def _milvus_client_kwargs(create_params: Dict[str, Any]) -> Dict[str, Any]: + client_kwargs = {"uri": create_params["milvus_uri"]} + username = create_params.get("username") + password = create_params.get("password") + if username or password: + client_kwargs["token"] = f"{username or ''}:{password or ''}" + alias = create_params.get("alias") + if alias is not None: + client_kwargs["alias"] = alias + return client_kwargs + + def _write_milvus_streaming(self, nvingest_records: List[List[Dict[str, Any]]]) -> None: + """Stream a batch to Milvus without per-batch wait_for_index.""" + ( + MilvusClient, + cleanup_records, + create_nvingest_collection, + pandas_file_reader, + ) = self._load_milvus_write_helpers() + + collection_name, create_params = self._client_vdb.get_connection_params() + if not isinstance(collection_name, str): + raise ValueError( + "VDBUploadOperator's Milvus streaming path requires a string collection_name; " + f"got {type(collection_name).__name__}." + ) + if bool(create_params.get("sparse", False)): + raise NotImplementedError( + "Milvus sparse/hybrid ingestion via VDBUploadOperator's streaming path is not yet supported. " + "Construct the Milvus client with sparse=False." 
+ ) + + if self._milvus_client is None: + create_nvingest_collection(collection_name=collection_name, **create_params) + self._milvus_client = MilvusClient(**self._milvus_client_kwargs(create_params)) + self._index_created = True + + meta_dataframe = getattr(self._client_vdb, "meta_dataframe", None) + if isinstance(meta_dataframe, str): + meta_dataframe = pandas_file_reader(meta_dataframe) + + cleaned_records = cleanup_records( + nvingest_records, + enable_text=getattr(self._client_vdb, "enable_text", True), + enable_charts=getattr(self._client_vdb, "enable_charts", True), + enable_tables=getattr(self._client_vdb, "enable_tables", True), + enable_images=getattr(self._client_vdb, "enable_images", True), + enable_infographics=getattr(self._client_vdb, "enable_infographics", True), + meta_dataframe=meta_dataframe, + meta_source_field=getattr(self._client_vdb, "meta_source_field", None), + meta_fields=getattr(self._client_vdb, "meta_fields", None), + ) + if not cleaned_records: + logger.warning("No records with embeddings to insert into Milvus.") + return + + self._milvus_client.insert(collection_name=collection_name, data=cleaned_records) + + @staticmethod + def _load_milvus_write_helpers() -> tuple[Any, Any, Any, Any]: + try: + from nv_ingest_client.util.vdb.milvus import ( + MilvusClient, + cleanup_records, + create_nvingest_collection, + pandas_file_reader, + ) + except ImportError as exc: + raise VDBUploadOperator._milvus_import_error() from exc + if MilvusClient is None: + raise VDBUploadOperator._milvus_import_error() + return MilvusClient, cleanup_records, create_nvingest_collection, pandas_file_reader + + @staticmethod + def _load_milvus_finalize_helpers() -> tuple[Any, Any]: + try: + from nv_ingest_client.util.vdb.milvus import MilvusClient, wait_for_index + except ImportError as exc: + raise VDBUploadOperator._milvus_import_error() from exc + if MilvusClient is None: + raise VDBUploadOperator._milvus_import_error() + return MilvusClient, wait_for_index + + 
def _finalize_milvus(self) -> None: + """Flush Milvus and wait once for all collection indexes to catch up.""" + MilvusClient, wait_for_index = self._load_milvus_finalize_helpers() + + collection_name, create_params = self._client_vdb.get_connection_params() + if not isinstance(collection_name, str): + raise ValueError( + "VDBUploadOperator.finalize() requires a string Milvus collection_name; " + f"got {type(collection_name).__name__}." + ) + + client = MilvusClient(**self._milvus_client_kwargs(create_params)) + if hasattr(client, "has_collection") and not client.has_collection(collection_name): + return + + client.flush(collection_name) + row_count = int(client.get_collection_stats(collection_name=collection_name).get("row_count", 0)) + if row_count == 0: + return + + index_names = client.list_indexes(collection_name) + expected_rows = {index_name: row_count for index_name in index_names} + if getattr(self._client_vdb, "no_wait_index", False) or not expected_rows: + return + + wait_for_index(collection_name, expected_rows, client) diff --git a/nemo_retriever/src/nemo_retriever/graph_ingestor.py b/nemo_retriever/src/nemo_retriever/graph_ingestor.py index 35fe52611..ba6c5a213 100644 --- a/nemo_retriever/src/nemo_retriever/graph_ingestor.py +++ b/nemo_retriever/src/nemo_retriever/graph_ingestor.py @@ -27,6 +27,7 @@ from __future__ import annotations import json +import logging import os import sys from typing import Any, Callable, Dict, List, Optional, Union @@ -45,10 +46,13 @@ HtmlChunkParams, StoreParams, TextChunkParams, + VdbUploadParams, WebhookParams, ) from nemo_retriever.utils.remote_auth import resolve_remote_api_key +logger = logging.getLogger(__name__) + def _resolve_api_key(params: Any) -> Any: """Auto-resolve api_key from NVIDIA_API_KEY / NGC_API_KEY if not explicitly set.""" @@ -149,6 +153,8 @@ def __init__( self._caption_params: Any = None self._dedup_params: Any = None self._store_params: Any = None + self._vdb_upload_params: Any = None + 
self._vdb_op: Any = None self._webhook_params: Any = None # Ordered list of stage names; "extract" is tracked but excluded from # the post-extraction stage_order passed to graph builders. @@ -243,6 +249,35 @@ def embed(self, params: Optional[EmbedParams] = None, **kwargs: Any) -> "GraphIn self._record_stage("embed") return self + def vdb_upload( + self, + params: Optional[VdbUploadParams] = None, + *, + vdb_op: Any = None, + **kwargs: Any, + ) -> "GraphIngestor": + """Record a VDB upload sink. + + Writes embeddings to a vector store as data flows through the graph, + eliminating the need to collect all results on the driver. Index + creation happens automatically after the pipeline completes. + + Parameters + ---------- + params: + Upload/record-shaping config. When ``vdb_op`` is ``None`` this + also determines which backend to construct. + vdb_op: + Optional pre-constructed client VDB instance (e.g. ``Milvus(...)`` + or ``LanceDB(...)`` from ``nv_ingest_client.util.vdb``). When + provided, the operator uses it directly instead of building one + from ``params``. + """ + self._vdb_upload_params = _coerce(params, kwargs, default_factory=VdbUploadParams) + self._vdb_op = vdb_op + self._record_stage("vdb_upload") + return self + def webhook(self, params: Optional[WebhookParams] = None, **kwargs: Any) -> "GraphIngestor": """Record a webhook notification stage (always runs last). 
@@ -305,6 +340,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: ) cluster_resources = gather_cluster_resources(ray) + vdb_upload_ops: list[Any] = [] graph = build_graph( extraction_mode=self._extraction_mode, extract_params=self._extract_params, @@ -317,8 +353,11 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: caption_params=self._caption_params, dedup_params=self._dedup_params, store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, + vdb_op=self._vdb_op, webhook_params=self._webhook_params, stage_order=post_extract_order, + vdb_upload_ops_out=vdb_upload_ops, ) # Derive per-node Ray scheduling config from BatchTuningParams plus # cluster-scaled heuristic defaults, then let any explicit @@ -331,6 +370,16 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: allow_no_gpu=effective_allow_no_gpu, caption_params=self._caption_params, ) + # VDBUploadOperator must run with concurrency=1 to avoid table + # creation races — a single actor creates the table on its first + # write and appends on subsequent batches. A large batch_size + # amortises per-call LanceDB disk I/O across many rows. 
+ if self._vdb_upload_params is not None: + from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator + + vdb_overrides = derived_overrides.setdefault(VDBUploadOperator.__name__, {}) + vdb_overrides["concurrency"] = 1 + vdb_overrides["batch_size"] = 64 merged_overrides: Dict[str, Dict[str, Any]] = {} for node_name in set(derived_overrides) | set(self._node_overrides): merged_overrides[node_name] = { @@ -347,8 +396,12 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: ) result = executor.ingest(self._documents) self._rd_dataset = result + for op in vdb_upload_ops: + op.finalize() + self._finalize_vdb() return result else: + vdb_upload_ops = [] graph = build_graph( extraction_mode=self._extraction_mode, extract_params=self._extract_params, @@ -361,12 +414,19 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: caption_params=self._caption_params, dedup_params=self._dedup_params, store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, + vdb_op=self._vdb_op, webhook_params=self._webhook_params, stage_order=post_extract_order, + vdb_upload_ops_out=vdb_upload_ops, ) executor = InprocessExecutor(graph, show_progress=self._show_progress) self._rd_dataset = None - return executor.ingest(self._documents) + result = executor.ingest(self._documents) + for op in vdb_upload_ops: + op.finalize() + self._finalize_vdb() + return result # ------------------------------------------------------------------ # Internal helpers @@ -441,6 +501,75 @@ def get_error_rows(self, dataset: Any = None) -> Any: def get_dataset(self) -> Any: return self._rd_dataset + def _finalize_vdb(self) -> None: + """Create VDB indices after the pipeline writes all rows. + + Delegates to the client VDB class for index creation. For LanceDB + the client's ``write_to_index(table=table)`` builds the ANN (and + optionally FTS) index. For other backends index creation is handled + during streaming writes. 
+ """ + if self._vdb_upload_params is None: + return + + from nemo_retriever.params.models import LanceDbParams, VdbUploadParams + + if self._vdb_op is not None: + backend_type = type(self._vdb_op).__name__.lower() + elif isinstance(self._vdb_upload_params, VdbUploadParams): + backend_type = self._vdb_upload_params.backend + elif isinstance(self._vdb_upload_params, LanceDbParams): + backend_type = "lancedb" + else: + return + + if backend_type != "lancedb": + return + + if isinstance(self._vdb_upload_params, VdbUploadParams): + lance_params = self._vdb_upload_params.lancedb + elif isinstance(self._vdb_upload_params, LanceDbParams): + lance_params = self._vdb_upload_params + else: + lance_params = LanceDbParams() + + if not lance_params.create_index: + return + + import lancedb + + from nemo_retriever.vector_store.lancedb_utils import build_client_lancedb + + uri = getattr(self._vdb_op, "uri", lance_params.lancedb_uri) if self._vdb_op else lance_params.lancedb_uri + table_name = ( + getattr(self._vdb_op, "table_name", lance_params.table_name) if self._vdb_op else lance_params.table_name + ) + + try: + db = lancedb.connect(uri=uri) + table = db.open_table(table_name) + except Exception: + return + + client = self._vdb_op if self._vdb_op is not None else build_client_lancedb(lance_params, overwrite=False) + try: + client.write_to_index( + None, + table=table, + index_type=getattr(client, "index_type", lance_params.index_type), + metric=getattr(client, "metric", lance_params.metric), + num_partitions=getattr(client, "num_partitions", lance_params.num_partitions), + num_sub_vectors=getattr(client, "num_sub_vectors", lance_params.num_sub_vectors), + hybrid=getattr(client, "hybrid", lance_params.hybrid), + fts_language=getattr(client, "fts_language", lance_params.fts_language), + ) + except RuntimeError: + logger.warning( + "Index creation failed (likely too few rows for %d partitions); skipping.", + lance_params.num_partitions, + exc_info=True, + ) + def 
_record_stage(self, name: str) -> None: """Append *name* to the stage order list (deduplicated in place).""" if name not in self._stage_order: diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index ce4cc4825..9b714802d 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -293,6 +293,8 @@ def _warn_page_granularity_overrides(self) -> "EmbedParams": class VdbUploadParams(_ParamsModel): purge_results_after_upload: bool = True lancedb: LanceDbParams = Field(default_factory=LanceDbParams) + backend: str = "lancedb" # "lancedb" | "milvus" | "opensearch" + client_vdb_kwargs: dict[str, Any] = Field(default_factory=dict) class StoreParams(_ParamsModel): diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index e5be26dea..3f993388d 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -307,10 +307,13 @@ def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]: keys: List[str] = [] for h in hits: page_number = h["page_number"] - source = h["source"] + raw_source = h["source"] + # source may be a bare path string or a JSON object {"source_id": "..."}. + source_map = _parse_mapping(raw_source) + source = source_map.get("source_id", raw_source) if source_map else raw_source # Prefer explicit `pdf_page` column; fall back to derived form. 
if page_number is not None and source: - filename = Path(source).stem + filename = Path(str(source)).stem keys.append(f"{filename}_{str(page_number)}") else: logger.warning( diff --git a/nemo_retriever/src/nemo_retriever/text_embed/processor.py b/nemo_retriever/src/nemo_retriever/text_embed/processor.py index 81dd4b8a6..662829112 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/processor.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/processor.py @@ -14,7 +14,8 @@ from nv_ingest_api.internal.transform.embed_text import transform_create_text_embeddings_internal from nemo_retriever.io.dataframe import validate_primitives_dataframe -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_embeddings_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_embeddings_to_lancedb logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def embed_text_from_primitives_df( *, transform_config: TextEmbeddingSchema, task_config: Optional[Dict[str, Any]] = None, - lancedb: Optional[LanceDBConfig] = None, + lancedb: Optional[LanceDbParams] = None, trace_info: Optional[Dict[str, Any]] = None, ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """Generate embeddings for supported content types and write to metadata.""" diff --git a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py index 1c05e4e8f..3320092d4 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py @@ -4,15 +4,15 @@ from .__main__ import app from .lancedb_store import ( - LanceDBConfig, create_lancedb_index, write_embeddings_to_lancedb, write_text_embeddings_dir_to_lancedb, ) +from .vdb_records import build_vdb_records __all__ = [ "app", - "LanceDBConfig", + "build_vdb_records", "create_lancedb_index", "write_embeddings_to_lancedb", "write_text_embeddings_dir_to_lancedb", 
diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py index 70503229a..d99e3c09e 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py @@ -23,6 +23,23 @@ } +def _ensure_dict(value: Any) -> Optional[Dict[str, Any]]: + """Coerce *value* to a dict, parsing JSON strings if needed. + + Arrow serialization between Ray actors can convert dict columns to + strings. This helper lets downstream code handle both forms. + """ + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + return None + return None + + def extract_embedding_from_row( row: Any, *, @@ -30,14 +47,14 @@ def extract_embedding_from_row( embedding_key: str = "embedding", ) -> Optional[List[float]]: """Extract an embedding vector from a row.""" - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): + meta = _ensure_dict(getattr(row, "metadata", None)) + if meta is not None: emb = meta.get("embedding") if isinstance(emb, list) and emb: return emb # type: ignore[return-value] - payload = getattr(row, embedding_column, None) - if isinstance(payload, dict): + payload = _ensure_dict(getattr(row, embedding_column, None)) + if payload is not None: emb = payload.get(embedding_key) if isinstance(emb, list) and emb: return emb # type: ignore[return-value] @@ -60,8 +77,8 @@ def extract_source_path_and_page(row: Any) -> Tuple[str, int]: except Exception: pass - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): + meta = _ensure_dict(getattr(row, "metadata", None)) + if meta is not None: source_path = meta.get("source_path") if isinstance(source_path, str) and source_path.strip(): path = source_path.strip() @@ -143,8 +160,8 @@ def build_lancedb_row( 
metadata_obj.update(_build_detection_metadata(row)) update_metadata_with_content_type(metadata_obj, content_type=getattr(row, "_content_type", None)) - orig_meta = getattr(row, "metadata", None) - if isinstance(orig_meta, dict): + orig_meta = _ensure_dict(getattr(row, "metadata", None)) + if orig_meta is not None: for key in ("chunk_index", "chunk_count"): if key in orig_meta: metadata_obj[key] = orig_meta[key] @@ -180,29 +197,6 @@ def build_lancedb_row( return row_out -def build_lancedb_rows( - df: Any, - *, - embedding_column: str = "text_embeddings_1b_v2", - embedding_key: str = "embedding", - text_column: str = "text", - include_text: bool = True, -) -> List[Dict[str, Any]]: - """Build LanceDB rows from a pandas DataFrame.""" - rows: List[Dict[str, Any]] = [] - for row in df.itertuples(index=False): - row_out = build_lancedb_row( - row, - embedding_column=embedding_column, - embedding_key=embedding_key, - text_column=text_column, - include_text=include_text, - ) - if row_out is not None: - rows.append(row_out) - return rows - - def lancedb_schema(vector_dim: int = 2048) -> Any: """Return a PyArrow schema for the standard LanceDB table layout.""" import pyarrow as pa # type: ignore @@ -235,6 +229,28 @@ def infer_vector_dim(rows: List[Dict[str, Any]]) -> int: return 0 +def build_client_lancedb(lance_params: Any, *, overwrite: Optional[bool] = None) -> Any: + """Construct an ``nv_ingest_client.util.vdb.LanceDB`` from ``LanceDbParams``. + + Pass ``overwrite=False`` to open an existing table (e.g. for post-pipeline + index creation); otherwise ``lance_params.overwrite`` is used. 
+ """ + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + return LanceDB( + uri=lance_params.lancedb_uri, + table_name=lance_params.table_name, + overwrite=lance_params.overwrite if overwrite is None else overwrite, + index_type=lance_params.index_type, + metric=lance_params.metric, + num_partitions=lance_params.num_partitions, + num_sub_vectors=lance_params.num_sub_vectors, + hybrid=lance_params.hybrid, + fts_language=lance_params.fts_language, + ) + + def create_or_append_lancedb_table( db: Any, table_name: str, diff --git a/nemo_retriever/src/nemo_retriever/vector_store/stage.py b/nemo_retriever/src/nemo_retriever/vector_store/stage.py index ba3f994a1..7c59d718b 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/stage.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/stage.py @@ -10,7 +10,8 @@ import typer from rich.console import Console -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_text_embeddings_dir_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_text_embeddings_dir_to_lancedb console = Console() app = typer.Typer(help="Vector store stage: upload stage5 embeddings to a vector DB (LanceDB).") @@ -54,8 +55,8 @@ def run( - `page_number`: page number from `metadata.content_metadata.page_number` - `path` / `source_id`: source identifiers """ - cfg = LanceDBConfig( - uri=str(lancedb_uri), + cfg = LanceDbParams( + lancedb_uri=str(lancedb_uri), table_name=str(table_name), overwrite=bool(overwrite), create_index=bool(create_index), @@ -73,7 +74,7 @@ def run( ) console.print( f"[green]Done[/green] files={info['n_files']} processed={info['processed']} skipped={info['skipped']} " - f"failed={info['failed']} lancedb_uri={cfg.uri} table={cfg.table_name}" + f"failed={info['failed']} lancedb_uri={cfg.lancedb_uri} table={cfg.table_name}" ) diff --git 
a/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py new file mode 100644 index 000000000..c25d2256e --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Canonical VDB record builder. + +Converts a pandas DataFrame (the graph pipeline's output format) into a list +of backend-neutral VDB record dicts. Every VDB backend in ``nemo_retriever`` +consumes this record format — it is the single source of truth for the +DataFrame → VDB record contract. + +Canonical record schema (matches ``retriever.py`` query expectations):: + + vector : list[float] # embedding + text : str # content + metadata : str # JSON string (round-trips via json.loads) + source : str # JSON string {"source_id": "..."} + page_number : int + pdf_page : str # "basename_pagenum" + pdf_basename : str + filename : str + source_id : str + path : str +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +import pandas as pd + +from nemo_retriever.vector_store.lancedb_utils import build_lancedb_row + + +def build_vdb_records( + df: pd.DataFrame, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Convert a post-embed DataFrame into canonical VDB records. + + Rows without a valid embedding are silently skipped. 
+ """ + rows: List[Dict[str, Any]] = [] + for row in df.itertuples(index=False): + row_out = build_lancedb_row( + row, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) + if row_out is not None: + rows.append(row_out) + return rows diff --git a/nemo_retriever/tests/integration_test_milvus_recall.py b/nemo_retriever/tests/integration_test_milvus_recall.py new file mode 100644 index 000000000..bbeddfe4e --- /dev/null +++ b/nemo_retriever/tests/integration_test_milvus_recall.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end Milvus recall test: ingest jp20 → embed queries → search → compute recall. + +Requires: + - Running Milvus instance + - GPU for extraction + embedding + - jp20 dataset at /datasets/nv-ingest/jp20 + - Query CSV at data/jp20_query_gt.csv + +Usage: + python tests/integration_test_milvus_recall.py [--milvus-uri URI] +""" + +from __future__ import annotations + +import sys +import time +from pathlib import Path + +MILVUS_URI = "http://172.20.0.4:19530" +COLLECTION_NAME = "jp20_recall_test" +DATASET_DIR = "/datasets/nv-ingest/jp20" +QUERY_CSV = "/raid/jioffe/NeMo-Retriever/data/jp20_query_gt.csv" +EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2" +EMBED_DIM = 2048 +TOP_K = 10 + + +def ingest_jp20_to_milvus(): + """Run the full graph pipeline with Milvus as the VDB backend.""" + from nemo_retriever import create_ingestor + from nemo_retriever.params import EmbedParams, VdbUploadParams + + print("--- Step 1: Ingesting jp20 into Milvus ---") + t0 = time.perf_counter() + + ingestor = ( + create_ingestor(run_mode="batch") + .files(DATASET_DIR + "/*.pdf") + .extract(extract_text=True, extract_tables=True, extract_charts=True) + .embed( + EmbedParams( + model_name=EMBED_MODEL, + inference_batch_size=32, + ) + ) + .vdb_upload( + VdbUploadParams( + 
backend="milvus", + client_vdb_kwargs={ + "milvus_uri": MILVUS_URI, + "collection_name": COLLECTION_NAME, + "dense_dim": EMBED_DIM, + "recreate": True, + "gpu_index": False, + "stream": True, + "sparse": False, + }, + ) + ) + ) + ingestor.ingest() + + elapsed = time.perf_counter() - t0 + print(f"[OK] Ingestion complete in {elapsed:.1f}s") + return elapsed + + +def embed_queries(): + """Embed jp20 queries using the local model.""" + import pandas as pd + + from nemo_retriever.model import create_local_embedder + + print("--- Step 2: Embedding queries ---") + + df = pd.read_csv(QUERY_CSV) + queries = df["query"].astype(str).tolist() + gold_keys = [] + for _, row in df.iterrows(): + if "pdf_page" in df.columns: + gold_keys.append(str(row["pdf_page"])) + else: + gold_keys.append(f"{row['pdf']}_{row['page']}") + + print(f" {len(queries)} queries, embedding with {EMBED_MODEL}...") + embedder = create_local_embedder(EMBED_MODEL, device="cuda") + vecs = embedder.embed(["query: " + q for q in queries], batch_size=32) + query_embeddings = vecs.detach().to("cpu").tolist() + print(f"[OK] Embedded {len(query_embeddings)} queries") + + return queries, gold_keys, query_embeddings + + +def search_milvus(query_embeddings): + """Search Milvus collection with pre-computed query embeddings.""" + from pymilvus import MilvusClient + + print("--- Step 3: Searching Milvus ---") + + client = MilvusClient(uri=MILVUS_URI) + t0 = time.perf_counter() + + all_hits = [] + # Batch queries to avoid overwhelming Milvus + batch_size = 20 + for i in range(0, len(query_embeddings), batch_size): + batch = query_embeddings[i : i + batch_size] + results = client.search( + collection_name=COLLECTION_NAME, + data=batch, + limit=TOP_K, + output_fields=["text", "source", "content_metadata"], + ) + all_hits.extend(results) + + elapsed = time.perf_counter() - t0 + print(f"[OK] Searched {len(query_embeddings)} queries in {elapsed:.1f}s ({len(query_embeddings)/elapsed:.1f} QPS)") + + return all_hits + + +def 
hits_to_keys(all_hits): + """Extract pdf_page keys from Milvus search results.""" + import json + + retrieved_keys = [] + for hits in all_hits: + keys = [] + for h in hits: + entity = h.get("entity", {}) + # Milvus stores content_metadata and source as dicts (not JSON strings) + content_meta = entity.get("content_metadata", {}) + source = entity.get("source", {}) + + if isinstance(content_meta, str): + try: + content_meta = json.loads(content_meta) + except (json.JSONDecodeError, TypeError): + content_meta = {} + if isinstance(source, str): + try: + source = json.loads(source) + except (json.JSONDecodeError, TypeError): + source = {} + + source_id = source.get("source_id", "") + page_number = content_meta.get("page_number") + if page_number is None: + hierarchy = content_meta.get("hierarchy", {}) + if isinstance(hierarchy, dict): + page_number = hierarchy.get("page") + + if source_id and page_number is not None: + filename = Path(str(source_id)).stem + keys.append(f"{filename}_{page_number}") + retrieved_keys.append(keys) + + return retrieved_keys + + +def compute_recall(gold_keys, retrieved_keys, ks=(1, 5, 10)): + """Compute recall@k metrics.""" + print("--- Step 4: Computing recall ---") + + metrics = {} + for k in ks: + hits = 0 + for gold, retrieved in zip(gold_keys, retrieved_keys): + parts = str(gold).rsplit("_", 1) + if len(parts) == 2: + specific = f"{parts[0]}_{parts[1]}" + whole_doc = f"{parts[0]}_-1" + top = retrieved[:k] + if specific in top or whole_doc in top: + hits += 1 + elif gold in retrieved[:k]: + hits += 1 + recall = hits / max(1, len(gold_keys)) + metrics[f"recall@{k}"] = recall + print(f" recall@{k}: {recall:.4f}") + + return metrics + + +def cleanup(): + from pymilvus import MilvusClient + + client = MilvusClient(uri=MILVUS_URI) + if client.has_collection(COLLECTION_NAME): + client.drop_collection(COLLECTION_NAME) + print(f"[OK] Cleaned up collection {COLLECTION_NAME}") + + +def main(): + print("=" * 60) + print("Milvus End-to-End Recall 
Test (jp20)") + print(f" Milvus: {MILVUS_URI}") + print(f" Collection: {COLLECTION_NAME}") + print(f" Dataset: {DATASET_DIR}") + print(f" Queries: {QUERY_CSV}") + print("=" * 60) + print() + + try: + ingest_secs = ingest_jp20_to_milvus() + queries, gold_keys, query_embeddings = embed_queries() + all_hits = search_milvus(query_embeddings) + retrieved_keys = hits_to_keys(all_hits) + metrics = compute_recall(gold_keys, retrieved_keys) + + print() + print("=" * 60) + print("RESULTS") + print(f" Ingestion: {ingest_secs:.1f}s") + for k, v in metrics.items(): + print(f" {k}: {v:.4f}") + print() + + # Compare against LanceDB baseline + baseline = {"recall@1": 0.6435, "recall@5": 0.8783, "recall@10": 0.9304} + print("vs LanceDB baseline:") + all_match = True + for k, bl in baseline.items(): + mv = metrics.get(k, 0) + delta = mv - bl + status = "MATCH" if abs(delta) < 0.01 else ("BETTER" if delta > 0 else "WORSE") + print(f" {k}: Milvus={mv:.4f} LanceDB={bl:.4f} ({status})") + if status == "WORSE": + all_match = False + + print() + if all_match: + print("=== PASS: Milvus recall matches LanceDB baseline ===") + else: + print("=== NOTE: Milvus recall differs from LanceDB baseline ===") + + return 0 + + except Exception as e: + print(f"\n[FAIL] {e}") + import traceback + + traceback.print_exc() + return 1 + + finally: + cleanup() + + +if __name__ == "__main__": + if "--milvus-uri" in sys.argv: + idx = sys.argv.index("--milvus-uri") + MILVUS_URI = sys.argv[idx + 1] + sys.exit(main()) diff --git a/nemo_retriever/tests/integration_test_milvus_vdb.py b/nemo_retriever/tests/integration_test_milvus_vdb.py new file mode 100644 index 000000000..c2cf49302 --- /dev/null +++ b/nemo_retriever/tests/integration_test_milvus_vdb.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration test: VDBUploadOperator writing to a real Milvus instance. 
+ +Requires a running Milvus at the URI below. Run manually: + + python tests/integration_test_milvus_vdb.py + +This is NOT part of the pytest suite — it requires external infrastructure. +""" + +from __future__ import annotations + +import os +import sys + +import pandas as pd + +MILVUS_URI = os.environ.get("NEMO_RETRIEVER_MILVUS_URI", "http://172.20.0.4:19530") +COLLECTION_NAME = "nemo_retriever_integration_test" +EMBED_DIM = 128 + + +def _make_embedded_df(n: int = 10, dim: int = EMBED_DIM) -> pd.DataFrame: + """Build a minimal post-embed DataFrame.""" + rows = [] + for i in range(n): + embedding = [float(i * dim + j) / (n * dim) for j in range(dim)] + metadata = { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": f"This is the content of page {i} in the test document.", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + +def main(): + from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator + from nemo_retriever.params.models import VdbUploadParams + + print(f"=== Milvus Integration Test ===") + print(f"Milvus URI: {MILVUS_URI}") + print(f"Collection: {COLLECTION_NAME}") + print() + + # --- Step 1: Create operator with Milvus backend --- + params = VdbUploadParams( + backend="milvus", + client_vdb_kwargs={ + "milvus_uri": MILVUS_URI, + "collection_name": COLLECTION_NAME, + "dense_dim": EMBED_DIM, + "recreate": True, + "gpu_index": False, + "stream": True, + "sparse": False, + }, + ) + op = VDBUploadOperator(params=params) + print(f"[OK] VDBUploadOperator created with backend='milvus'") + + # --- Step 2: Run two batches through the operator --- + df_batch1 = _make_embedded_df(5) + df_batch2 = _make_embedded_df(5) + + result1 = op.run(df_batch1) + print(f"[OK] Batch 1: {len(result1)} rows 
passed through, records written") + + result2 = op.run(df_batch2) + print(f"[OK] Batch 2: {len(result2)} rows passed through, records written") + + # --- Step 3: Finalize indexing and verify data landed in Milvus --- + op.finalize() + print("[OK] Finalized Milvus collection") + + from pymilvus import MilvusClient + + client = MilvusClient(uri=MILVUS_URI) + + if not client.has_collection(COLLECTION_NAME): + print(f"[FAIL] Collection {COLLECTION_NAME} not found!") + return 1 + + stats = client.get_collection_stats(COLLECTION_NAME) + row_count = stats.get("row_count", 0) + print(f"[OK] Collection exists with {row_count} rows") + + # Query to verify data is searchable + query_vector = _make_embedded_df(1).iloc[0]["metadata"]["embedding"] + results = client.search( + collection_name=COLLECTION_NAME, + data=[query_vector], + limit=3, + output_fields=["text"], + ) + print(f"[OK] Search returned {len(results[0])} hits") + for i, hit in enumerate(results[0]): + text = hit.get("entity", {}).get("text", "")[:60] + print(f" Hit {i}: distance={hit['distance']:.4f} text='{text}...'") + + # --- Step 4: Cleanup --- + client.drop_collection(COLLECTION_NAME) + print(f"[OK] Cleaned up collection {COLLECTION_NAME}") + + print() + print("=== PASS: Milvus integration test completed successfully ===") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/nemo_retriever/tests/integration_test_vdb_op_passthrough.py b/nemo_retriever/tests/integration_test_vdb_op_passthrough.py new file mode 100644 index 000000000..fbdfc8d0a --- /dev/null +++ b/nemo_retriever/tests/integration_test_vdb_op_passthrough.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration test: pass pre-constructed client VDB instances into the graph. 
+ +Exercises the lead-review ask for PR 1847 — that `GraphIngestor.vdb_upload` +accept a pre-built `nv_ingest_client.util.vdb` instance (`LanceDB(...)` or +`Milvus(...)`) and wire it directly into `VDBUploadOperator` without the +operator reconstructing it from `VdbUploadParams`. + +Runs a small bo20 ingest against each backend and verifies the target table / +collection was populated. + +Requires: + - Running Milvus at ``MILVUS_URI`` + - GPU for extraction + embedding + - bo20 dataset at ``DATASET_DIR`` + +Usage: + python tests/integration_test_vdb_op_passthrough.py [--milvus-uri URI] +""" + +from __future__ import annotations + +import shutil +import sys +import tempfile +import time +import traceback +from pathlib import Path + +MILVUS_URI = "http://172.20.0.4:19530" +MILVUS_COLLECTION = "vdb_op_passthrough_bo20" +DATASET_DIR = "/datasets/nv-ingest/bo20" +EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2" +EMBED_DIM = 2048 + + +def _run_ingest(vdb_op): + from nemo_retriever import create_ingestor + from nemo_retriever.params import EmbedParams + + t0 = time.perf_counter() + ( + create_ingestor(run_mode="batch") + .files(DATASET_DIR + "/*.pdf") + .extract(extract_text=True, extract_tables=True, extract_charts=True) + .embed(EmbedParams(model_name=EMBED_MODEL, inference_batch_size=32)) + .vdb_upload(vdb_op=vdb_op) + .ingest() + ) + return time.perf_counter() - t0 + + +def test_lancedb_passthrough(): + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + tmp_dir = Path(tempfile.mkdtemp(prefix="vdb_op_passthrough_lance_")) + uri = str(tmp_dir / "lancedb") + table_name = "bo20_lance_passthrough" + + print("--- LanceDB pre-constructed instance ---") + print(f" uri: {uri}") + print(f" table: {table_name}") + + client = LanceDB(uri=uri, table_name=table_name, overwrite=True) + assert type(client).__name__ == "LanceDB", f"expected LanceDB, got {type(client).__name__}" + + elapsed = _run_ingest(client) + + import lancedb + + db = 
lancedb.connect(uri) + table = db.open_table(table_name) + n = table.count_rows() + print(f"[OK] LanceDB passthrough ingested {n} rows in {elapsed:.1f}s") + + shutil.rmtree(tmp_dir, ignore_errors=True) + return n > 0 + + +def test_milvus_passthrough(): + from nv_ingest_client.util.vdb import get_vdb_op_cls + from pymilvus import MilvusClient + + Milvus = get_vdb_op_cls("milvus") + print("--- Milvus pre-constructed instance ---") + print(f" uri: {MILVUS_URI}") + print(f" collection: {MILVUS_COLLECTION}") + + client = Milvus( + milvus_uri=MILVUS_URI, + collection_name=MILVUS_COLLECTION, + dense_dim=EMBED_DIM, + recreate=True, + gpu_index=False, + stream=True, + sparse=False, + ) + assert type(client).__name__ == "Milvus", f"expected Milvus, got {type(client).__name__}" + + elapsed = _run_ingest(client) + + mc = MilvusClient(uri=MILVUS_URI) + mc.load_collection(MILVUS_COLLECTION) + stats = mc.get_collection_stats(collection_name=MILVUS_COLLECTION) + n = int(stats.get("row_count", 0)) + print(f"[OK] Milvus passthrough ingested {n} rows in {elapsed:.1f}s") + + mc.drop_collection(MILVUS_COLLECTION) + return n > 0 + + +def main(): + print("=" * 60) + print("VDB-op passthrough integration test (bo20)") + print(f" Dataset: {DATASET_DIR}") + print("=" * 60) + + results = {} + try: + results["lancedb"] = test_lancedb_passthrough() + except Exception as exc: # noqa: BLE001 + print(f"[FAIL] LanceDB passthrough: {exc}") + traceback.print_exc() + results["lancedb"] = False + + try: + results["milvus"] = test_milvus_passthrough() + except Exception as exc: # noqa: BLE001 + print(f"[FAIL] Milvus passthrough: {exc}") + traceback.print_exc() + results["milvus"] = False + + print() + print("=" * 60) + print("RESULTS") + for backend, ok in results.items(): + print(f" {backend}: {'PASS' if ok else 'FAIL'}") + print("=" * 60) + return 0 if all(results.values()) else 1 + + +if __name__ == "__main__": + if "--milvus-uri" in sys.argv: + idx = sys.argv.index("--milvus-uri") + MILVUS_URI = 
sys.argv[idx + 1] + sys.exit(main()) diff --git a/nemo_retriever/tests/test_batch_pipeline.py b/nemo_retriever/tests/test_batch_pipeline.py index 33a2ee474..83a5aa887 100644 --- a/nemo_retriever/tests/test_batch_pipeline.py +++ b/nemo_retriever/tests/test_batch_pipeline.py @@ -2,6 +2,7 @@ import json from types import SimpleNamespace +import pandas as pd from typer.testing import CliRunner import nemo_retriever.examples.graph_pipeline as batch_pipeline @@ -20,6 +21,9 @@ def materialize(self): def take_all(self): return [] + def count(self): + return 1 + def groupby(self, _key): class _FakeGrouped: @staticmethod @@ -80,6 +84,18 @@ def embed(self, params): self.embed_params = params return self + def vdb_upload(self, params=None): + return self + + def store(self, params=None): + return self + + def caption(self, params=None): + return self + + def dedup(self, params=None): + return self + def ingest(self, params=None): return _FakeDataset() @@ -111,6 +127,19 @@ def test_graph_pipeline_resolves_nested_pdf_directories(tmp_path) -> None: assert patterns == [str(dataset_dir / "**" / "*.pdf")] +def test_graph_pipeline_counts_pages_separately_from_output_rows() -> None: + df = pd.DataFrame( + [ + {"path": "/data/a.pdf", "page_number": 1, "text": "page text"}, + {"path": "/data/a.pdf", "page_number": 1, "text": "table text"}, + {"path": "/data/a.pdf", "page_number": 2, "text": "page text"}, + ] + ) + + assert batch_pipeline._count_processed_pages_from_df(df) == 2 + assert len(df.index) == 3 + + def test_batch_pipeline_accepts_multimodal_embed_and_page_image_flags(tmp_path, monkeypatch) -> None: dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() @@ -119,8 +148,6 @@ def test_batch_pipeline_accepts_multimodal_embed_and_page_image_flags(tmp_path, fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - 
monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) class _FakeTable: @@ -165,8 +192,6 @@ def test_batch_pipeline_routes_audio_input_to_audio_ingestor(tmp_path, monkeypat fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) monkeypatch.setattr( batch_pipeline, "asr_params_from_env", lambda: SimpleNamespace(model_copy=lambda update: update) @@ -217,8 +242,6 @@ def test_batch_pipeline_routes_beir_mode_to_evaluator(tmp_path, monkeypatch) -> fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setattr(detection_summary_module, "print_run_summary", lambda *args, **kwargs: None) class _FakeTable: @@ -278,8 +301,6 @@ def test_batch_pipeline_accepts_harness_runtime_metric_flags(tmp_path, monkeypat fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) class _FakeTable: diff --git a/nemo_retriever/tests/test_vdb_record_contract.py b/nemo_retriever/tests/test_vdb_record_contract.py new file mode 100644 index 000000000..29fecdf85 --- /dev/null +++ 
b/nemo_retriever/tests/test_vdb_record_contract.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the canonical VDB record contract and LanceDB store helpers. + +These tests validate: + - build_vdb_records produces the correct canonical record format + - _ensure_dict handles Arrow serialization robustness +""" + +from __future__ import annotations + +import copy +import json + +import pandas as pd + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + + +def _make_sample_dataframe() -> pd.DataFrame: + """Build a minimal DataFrame matching the graph pipeline's post-embed output.""" + embedding = [0.1, 0.2, 0.3, 0.4] + metadata = { + "embedding": embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + } + return pd.DataFrame( + [ + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": "Hello world", + "path": "/data/test.pdf", + "page_number": 0, + "page_elements_v3_num_detections": 5, + "page_elements_v3_counts_by_label": {"text": 3, "table": 2}, + } + ] + ) + + +# --------------------------------------------------------------------------- +# Canonical record builder tests +# --------------------------------------------------------------------------- + + +class TestBuildVdbRecords: + def test_produces_all_required_fields(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert len(rows) == 1 + row = rows[0] + for field in ( + "vector", + "text", + "metadata", + "source", + "page_number", + "pdf_page", + "pdf_basename", + "source_id", + "path", + "filename", + ): + assert field in row, f"Missing required field: {field}" + + 
def test_metadata_is_valid_json(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert isinstance(meta, dict) + assert "page_number" in meta + + def test_metadata_includes_detection_counts(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert meta["page_elements_v3_num_detections"] == 5 + assert meta["page_elements_v3_counts_by_label"] == {"text": 3, "table": 2} + + def test_source_is_json_object(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + source = json.loads(rows[0]["source"]) + assert isinstance(source, dict) + assert "source_id" in source + + def test_does_not_mutate_input(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + original_meta = copy.deepcopy(df.iloc[0]["metadata"]) + + build_vdb_records(df) + + assert df.iloc[0]["metadata"] == original_meta + + def test_vector_is_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert rows[0]["vector"] == [0.1, 0.2, 0.3, 0.4] + + def test_skips_rows_without_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = pd.DataFrame( + [ + { + "metadata": {"source_path": "/data/test.pdf"}, + "text": "No embedding here", + "path": "/data/test.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 0 + + def test_empty_dataframe(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = pd.DataFrame() + rows = build_vdb_records(df) + assert rows == [] + + def 
test_text_content(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + assert rows[0]["text"] == "Hello world" + + def test_include_text_false(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df, include_text=False) + assert rows[0]["text"] == "" + + +# --------------------------------------------------------------------------- +# _ensure_dict — Arrow serialization robustness +# --------------------------------------------------------------------------- + + +class TestEnsureDict: + """Tests for the _ensure_dict helper that handles Arrow-serialized dict columns.""" + + def test_dict_passthrough(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + d = {"a": 1} + assert _ensure_dict(d) is d + + def test_json_string_parsed(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict('{"a": 1}') == {"a": 1} + + def test_none_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict(None) is None + + def test_non_dict_json_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict("[1, 2, 3]") is None + + def test_malformed_json_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict("not json{") is None + + def test_integer_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict(42) is None + + +class TestBuildVdbRecordsArrowCompat: + """build_vdb_records handles string-encoded dict columns from Arrow.""" + + def test_string_metadata_with_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + embedding = [0.1, 0.2, 0.3, 0.4] + metadata = json.dumps( + { + "embedding": 
embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + } + ) + df = pd.DataFrame( + [ + { + "metadata": metadata, + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "hello world", + "path": "/data/test.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 1 + assert rows[0]["vector"] == embedding + + def test_string_embed_column_only(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + embedding = [1.0, 2.0] + df = pd.DataFrame( + [ + { + "metadata": "{}", + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "test", + "path": "/x.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 1 + assert rows[0]["vector"] == embedding diff --git a/nemo_retriever/tests/test_vdb_upload_operator.py b/nemo_retriever/tests/test_vdb_upload_operator.py new file mode 100644 index 000000000..0b3b8be80 --- /dev/null +++ b/nemo_retriever/tests/test_vdb_upload_operator.py @@ -0,0 +1,452 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for VDBUploadOperator — the graph-based VDB write path.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator +from nemo_retriever.params.models import LanceDbParams, VdbUploadParams +from nv_ingest_client.util.vdb.adt_vdb import VDB + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +def _make_embedded_df(n: int = 3, dim: int = 4) -> pd.DataFrame: + """Build a minimal post-embed DataFrame with *n* rows.""" + rows = [] + for i in range(n): + embedding = [float(i + j) for j in range(dim)] + metadata = { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": f"content of page {i}", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + +@pytest.fixture() +def lance_params(tmp_path): + return LanceDbParams(lancedb_uri=str(tmp_path / "test_lancedb"), table_name="test_table") + + +@pytest.fixture() +def vdb_params(lance_params): + return VdbUploadParams(lancedb=lance_params) + + +class CustomVDB(VDB): + """Minimal custom backend proving the operator uses the legacy VDB ADT.""" + + def __init__(self): + super().__init__() + self.create_index_calls = [] + self.write_to_index_calls = [] + + def create_index(self, **kwargs): + self.create_index_calls.append(kwargs) + + def write_to_index(self, records: list, **kwargs): + self.write_to_index_calls.append((records, kwargs)) + + def retrieval(self, queries: list, **kwargs): + return [] + + def run(self, records): + self.create_index() + 
self.write_to_index(records) + + +# --------------------------------------------------------------------------- +# LanceDB write path tests +# --------------------------------------------------------------------------- + + +class TestVDBUploadOperator: + def test_writes_records_to_lancedb(self, vdb_params, lance_params): + """Operator writes canonical VDB records during process().""" + df = _make_embedded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 3 + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 3 + + def test_multiple_batches_accumulate(self, vdb_params, lance_params): + """Multiple process() calls append rows, not overwrite.""" + op = VDBUploadOperator(params=vdb_params) + op.run(_make_embedded_df(2)) + op.run(_make_embedded_df(3)) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 5 + + def test_empty_dataframe_is_noop(self, vdb_params): + """Empty DataFrame doesn't create a table or crash.""" + op = VDBUploadOperator(params=vdb_params) + result = op.run(pd.DataFrame()) + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 + assert op._client_vdb is None + + def test_no_embeddings_is_noop(self, vdb_params): + """DataFrame without embedding columns produces no VDB records.""" + df = pd.DataFrame({"text": ["hello"], "path": ["/test.txt"]}) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + assert len(result) == 1 + assert op._client_vdb is None + + def test_accepts_lance_params_directly(self, lance_params): + """Operator accepts LanceDbParams in addition to VdbUploadParams.""" + df = _make_embedded_df(1) + op = VDBUploadOperator(params=lance_params) + op.run(df) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = 
db.open_table(lance_params.table_name) + assert table.count_rows() == 1 + + def test_default_params(self): + """Operator works with no params (uses defaults).""" + op = VDBUploadOperator() + assert op._lance_params is not None + + def test_preprocess_extracts_records(self, vdb_params): + """preprocess populates _pending_records from the DataFrame.""" + df = _make_embedded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.preprocess(df) + + assert result is df + assert len(op._pending_records) == 3 + assert all("vector" in r for r in op._pending_records) + + +# --------------------------------------------------------------------------- +# Arrow serialization compat +# --------------------------------------------------------------------------- + + +class TestArrowSerializationCompat: + """Regression tests for Arrow-serialized dict columns (Ray Data pipeline).""" + + def _make_string_encoded_df(self, n: int = 3, dim: int = 4) -> pd.DataFrame: + """Build a DataFrame where dict columns are JSON strings, simulating Arrow round-trip.""" + rows = [] + for i in range(n): + embedding = [float(i + j) for j in range(dim)] + metadata = json.dumps( + { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + ) + embed_payload = json.dumps({"embedding": embedding, "info_msg": None}) + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": embed_payload, + "text": f"content of page {i}", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + def test_writes_records_with_string_metadata(self, vdb_params, lance_params): + """Operator handles Arrow-serialized string columns (not dicts).""" + df = self._make_string_encoded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + + assert len(result) == 3 + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = 
db.open_table(lance_params.table_name) + assert table.count_rows() == 3 + + def test_writes_records_with_string_embed_column_only(self, vdb_params, lance_params): + """Embedding extracted from string-encoded embedding column.""" + embedding = [1.0, 2.0, 3.0, 4.0] + df = pd.DataFrame( + [ + { + "metadata": {"source_path": "/test.pdf"}, + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "hello", + "path": "/test.pdf", + "page_number": 0, + } + ] + ) + op = VDBUploadOperator(params=vdb_params) + op.run(df) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 1 + + +# --------------------------------------------------------------------------- +# Milvus write path (mocked client) +# --------------------------------------------------------------------------- + + +class TestVDBUploadMilvus: + """VDBUploadOperator wrapping a Milvus client VDB.""" + + def _make_milvus_vdb(self): + from nv_ingest_client.util.vdb import get_vdb_op_cls + + Milvus = get_vdb_op_cls("milvus") + return Milvus( + milvus_uri="http://localhost:19530", + collection_name="test", + dense_dim=4, + recreate=True, + gpu_index=False, + sparse=False, + ) + + def test_streams_directly_to_milvus_client(self): + milvus_vdb = self._make_milvus_vdb() + op = VDBUploadOperator(vdb_op=milvus_vdb) + fake_client = MagicMock() + fake_client_cls = MagicMock(return_value=fake_client) + create_collection = MagicMock() + cleanup_records = MagicMock( + return_value=[ + {"text": "content 0", "vector": [0.0, 1.0, 2.0, 3.0]}, + {"text": "content 1", "vector": [1.0, 2.0, 3.0, 4.0]}, + {"text": "content 2", "vector": [2.0, 3.0, 4.0, 5.0]}, + ] + ) + + with patch.object( + VDBUploadOperator, + "_load_milvus_write_helpers", + return_value=(fake_client_cls, cleanup_records, create_collection, MagicMock()), + ): + op.run(_make_embedded_df(3)) + + create_collection.assert_called_once() + 
fake_client_cls.assert_called_once_with(uri="http://localhost:19530") + cleanup_records.assert_called_once() + fake_client.insert.assert_called_once() + assert fake_client.insert.call_args.kwargs["collection_name"] == "test" + assert len(fake_client.insert.call_args.kwargs["data"]) == 3 + + def test_multiple_batches_reuse_milvus_client(self): + milvus_vdb = self._make_milvus_vdb() + op = VDBUploadOperator(vdb_op=milvus_vdb) + fake_client = MagicMock() + fake_client_cls = MagicMock(return_value=fake_client) + create_collection = MagicMock() + cleanup_records = MagicMock(return_value=[{"text": "content", "vector": [0.0, 1.0, 2.0, 3.0]}]) + + with patch.object( + VDBUploadOperator, + "_load_milvus_write_helpers", + return_value=(fake_client_cls, cleanup_records, create_collection, MagicMock()), + ): + op.run(_make_embedded_df(2)) + op.run(_make_embedded_df(3)) + + create_collection.assert_called_once() + fake_client_cls.assert_called_once_with(uri="http://localhost:19530") + assert fake_client.insert.call_count == 2 + + def test_milvus_sparse_streaming_is_explicitly_unsupported(self): + milvus_vdb = self._make_milvus_vdb() + milvus_vdb.sparse = True + op = VDBUploadOperator(vdb_op=milvus_vdb) + + with patch.object( + VDBUploadOperator, + "_load_milvus_write_helpers", + return_value=(MagicMock(), MagicMock(), MagicMock(), MagicMock()), + ): + with pytest.raises(NotImplementedError, match="sparse/hybrid"): + op.run(_make_embedded_df(1)) + + +# --------------------------------------------------------------------------- +# Pre-constructed VDB injection +# --------------------------------------------------------------------------- + + +class TestPreConstructedVDB: + """Operator accepts a pre-built client VDB instance (lead's review ask).""" + + def test_accepts_client_lancedb_instance(self, lance_params): + """Passing a client LanceDB object skips internal construction.""" + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + 
client_vdb = LanceDB( + uri=lance_params.lancedb_uri, + table_name=lance_params.table_name, + overwrite=True, + ) + + op = VDBUploadOperator(vdb_op=client_vdb) + assert op._client_vdb is client_vdb + assert op._backend_name == "lancedb" + + op.run(_make_embedded_df(2)) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 2 + + def test_accepts_arbitrary_vdb_instance(self): + """A non-LanceDB custom instance routes through ADT methods only.""" + mock_vdb = CustomVDB() + + op = VDBUploadOperator(vdb_op=mock_vdb) + assert op._backend_name == "customvdb" + assert op._client_vdb is mock_vdb + + op.run(_make_embedded_df(2)) + + assert mock_vdb.create_index_calls == [{}] + assert len(mock_vdb.write_to_index_calls) == 1 + written_records, write_kwargs = mock_vdb.write_to_index_calls[0] + assert write_kwargs == {} + assert isinstance(written_records, list) + assert isinstance(written_records[0], list) + assert written_records[0][0]["document_type"] == "text" + assert not hasattr(mock_vdb, "get_connection_params") + assert not hasattr(mock_vdb, "get_write_params") + + def test_constructor_kwargs_round_trip(self): + """get_constructor_kwargs captures both params and vdb_op for Ray reconstruction.""" + mock_vdb = CustomVDB() + + op = VDBUploadOperator(params=VdbUploadParams(), vdb_op=mock_vdb) + kwargs = op.get_constructor_kwargs() + assert "vdb_op" in kwargs + assert kwargs["vdb_op"] is mock_vdb + assert "params" in kwargs + + +# --------------------------------------------------------------------------- +# Finalization +# --------------------------------------------------------------------------- + + +class TestVDBFinalization: + def test_lancedb_table_accessible_after_writes(self, vdb_params, lance_params): + """After operator writes, the table can be opened from disk.""" + df = _make_embedded_df(5) + op = VDBUploadOperator(params=vdb_params) + op.run(df) + + import lancedb + + 
db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 5 + + def test_lancedb_finalize_is_noop(self, lance_params): + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + op = VDBUploadOperator(vdb_op=LanceDB(uri=lance_params.lancedb_uri, table_name=lance_params.table_name)) + + with patch.object(op, "_finalize_milvus") as finalize_milvus: + op.finalize() + + finalize_milvus.assert_not_called() + + def test_milvus_finalize_waits_once(self): + from nv_ingest_client.util.vdb import get_vdb_op_cls + + Milvus = get_vdb_op_cls("milvus") + milvus_vdb = Milvus( + milvus_uri="http://localhost:19530", + collection_name="test", + dense_dim=4, + recreate=True, + gpu_index=False, + sparse=False, + ) + op = VDBUploadOperator(vdb_op=milvus_vdb) + fake_client = MagicMock() + fake_client.has_collection.return_value = True + fake_client.get_collection_stats.return_value = {"row_count": 7} + fake_client.list_indexes.return_value = ["dense_index", "sparse_index"] + fake_client_cls = MagicMock(return_value=fake_client) + wait_for_index = MagicMock() + + with patch.object( + VDBUploadOperator, + "_load_milvus_finalize_helpers", + return_value=(fake_client_cls, wait_for_index), + ): + op.finalize() + + fake_client_cls.assert_called_once_with(uri="http://localhost:19530") + fake_client.flush.assert_called_once_with("test") + wait_for_index.assert_called_once_with( + "test", + {"dense_index": 7, "sparse_index": 7}, + fake_client, + ) + + def test_milvus_finalize_constructs_params_backend(self): + params = VdbUploadParams(backend="milvus", client_vdb_kwargs={"collection_name": "test"}) + op = VDBUploadOperator(params=params) + fake_vdb = MagicMock() + fake_cls = MagicMock(return_value=fake_vdb) + + with ( + patch("nv_ingest_client.util.vdb.get_vdb_op_cls", return_value=fake_cls), + patch.object(VDBUploadOperator, "_is_milvus_vdb", return_value=True), + patch.object(op, 
"_finalize_milvus") as finalize_milvus, + ): + op.finalize() + + fake_cls.assert_called_once_with(collection_name="test") + assert op._client_vdb is fake_vdb + finalize_milvus.assert_called_once_with()