From 9ed88eca8dae3ead858305e9979736a5d33736a2 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Tue, 14 Apr 2026 16:10:54 +0000 Subject: [PATCH 1/6] vector_store: stream VDB uploads in graph pipeline via client VDB wrapper - VDBUploadOperator wraps client VDB classes (LanceDB, Milvus) as a streaming graph stage with concurrency=1 and batch_size=64 - Preprocess extracts canonical records; process writes per-backend - Finalization delegates to client LanceDB.write_to_index for indexing - jp20 recall@5=0.8783 (parity), PPS=21.79 (parity with baseline 21.50) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Jacob Ioffe --- .../nemo_retriever/examples/graph_pipeline.py | 82 ++-- .../src/nemo_retriever/graph/__init__.py | 2 + .../nemo_retriever/graph/ingestor_runtime.py | 21 +- .../graph/vdb_upload_operator.py | 216 +++++++++++ .../src/nemo_retriever/graph_ingestor.py | 89 ++++- .../src/nemo_retriever/params/models.py | 2 + .../src/nemo_retriever/recall/core.py | 7 +- .../nemo_retriever/text_embed/processor.py | 5 +- .../nemo_retriever/vector_store/__init__.py | 5 +- .../vector_store/lancedb_store.py | 239 +++--------- .../vector_store/lancedb_utils.py | 33 +- .../src/nemo_retriever/vector_store/stage.py | 9 +- .../vector_store/vdb_records.py | 79 ++++ .../tests/test_vdb_record_contract.py | 352 ++++++++++++++++++ .../tests/test_vdb_upload_operator.py | 267 +++++++++++++ 15 files changed, 1160 insertions(+), 248 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py create mode 100644 nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py create mode 100644 nemo_retriever/tests/test_vdb_record_contract.py create mode 100644 nemo_retriever/tests/test_vdb_upload_operator.py diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index a65778abe..770d7aa6f 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -48,9 +48,9 @@ from nemo_retriever.params import ExtractParams from nemo_retriever.params import StoreParams from nemo_retriever.params import TextChunkParams -from nemo_retriever.params.models import BatchTuningParams +from nemo_retriever.params import VdbUploadParams +from nemo_retriever.params.models import BatchTuningParams, LanceDbParams from nemo_retriever.utils.remote_auth import resolve_remote_api_key -from nemo_retriever.vector_store.lancedb_store import handle_lancedb logger = logging.getLogger(__name__) app = typer.Typer() @@ -119,23 +119,6 @@ def _configure_logging(log_file: Optional[Path], *, debug: bool = False) -> tupl return fh, original_stdout, original_stderr -def _ensure_lancedb_table(uri: str, table_name: str) -> None: - from nemo_retriever.vector_store.lancedb_utils import lancedb_schema - import lancedb - import pyarrow as pa - - Path(uri).mkdir(parents=True, exist_ok=True) - db = lancedb.connect(uri) - try: - db.open_table(table_name) - return - except Exception: - pass - schema = lancedb_schema() - empty = pa.table({f.name: [] for f in schema}, schema=schema) - db.create_table(table_name, data=empty, schema=schema, mode="create") - - def _write_runtime_summary( runtime_metrics_dir: Optional[Path], runtime_metrics_prefix: Optional[str], @@ -314,7 +297,6 @@ def main( os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0" lancedb_uri = str(Path(lancedb_uri).expanduser().resolve()) - _ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE) 
remote_api_key = resolve_remote_api_key(api_key) extract_remote_api_key = remote_api_key @@ -502,6 +484,21 @@ def main( ingestor = ingestor.embed(embed_params) + # VDB upload runs inside the graph — rows stream to LanceDB as they + # are produced, so we never need to collect the entire result set on + # the driver just for the LanceDB write. Index creation happens + # automatically in GraphIngestor._finalize_vdb() after the pipeline. + ingestor = ingestor.vdb_upload( + VdbUploadParams( + lancedb=LanceDbParams( + lancedb_uri=lancedb_uri, + table_name=LANCEDB_TABLE, + hybrid=hybrid, + overwrite=True, + ), + ) + ) + # ------------------------------------------------------------------ # Execute the graph via the executor # ------------------------------------------------------------------ @@ -509,7 +506,7 @@ def main( ingest_start = time.perf_counter() # GraphIngestor.ingest() builds the Graph, creates the executor, - # and calls executor.ingest(file_patterns) returning: + # calls executor.ingest(file_patterns), and finalizes the VDB index. # batch mode -> materialized ray.data.Dataset # inprocess mode -> pandas.DataFrame result = ingestor.ingest() @@ -517,28 +514,32 @@ def main( ingestion_only_total_time = time.perf_counter() - ingest_start # ------------------------------------------------------------------ - # Collect results + # Collect results (only when needed for detection summary / counting) # ------------------------------------------------------------------ if run_mode == "batch": import ray - ray_download_start = time.perf_counter() - ingest_local_results = result.take_all() - ray_download_time = time.perf_counter() - ray_download_start + if detection_summary_file is not None: + ray_download_start = time.perf_counter() + ingest_local_results = result.take_all() + ray_download_time = time.perf_counter() - ray_download_start - import pandas as pd + import pandas as pd - result_df = pd.DataFrame(ingest_local_results) - num_rows = _count_input_units(result_df) + result_df = pd.DataFrame(ingest_local_results) + num_rows = _count_input_units(result_df) + else: + ray_download_time = 0.0 + result_df = None + num_rows = result.count() else: import pandas as pd result_df = result - ingest_local_results = result_df.to_dict("records") ray_download_time = 0.0 num_rows = _count_input_units(result_df) - if detection_summary_file is not None: + if detection_summary_file is not None and result_df is not None: from nemo_retriever.utils.detection_summary import ( collect_detection_summary_from_df, write_detection_summary, @@ -549,13 +550,6 @@ def main( collect_detection_summary_from_df(result_df), ) - # ------------------------------------------------------------------ - # Write to LanceDB - # ------------------------------------------------------------------ - lancedb_write_start = time.perf_counter() - handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite") - lancedb_write_time = time.perf_counter() - lancedb_write_start - # ------------------------------------------------------------------ # Recall / BEIR evaluation # ------------------------------------------------------------------ @@ -574,10 +568,10 @@ def main( "input_path": str(Path(input_path).resolve()), "input_pages": int(num_rows), "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), + "num_rows": int(num_rows), "ingestion_only_secs": float(ingestion_only_total_time), "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), + 
"lancedb_write_secs": 0.0, "evaluation_secs": 0.0, "total_secs": float(time.perf_counter() - ingest_start), "evaluation_mode": evaluation_mode, @@ -641,10 +635,10 @@ def main( "input_path": str(Path(input_path).resolve()), "input_pages": int(num_rows), "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), + "num_rows": int(num_rows), "ingestion_only_secs": float(ingestion_only_total_time), "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), + "lancedb_write_secs": 0.0, "evaluation_secs": 0.0, "total_secs": float(time.perf_counter() - ingest_start), "evaluation_mode": evaluation_mode, @@ -690,10 +684,10 @@ def main( "input_path": str(Path(input_path).resolve()), "input_pages": int(num_rows), "num_pages": int(num_rows), - "num_rows": int(len(result_df.index)), + "num_rows": int(num_rows), "ingestion_only_secs": float(ingestion_only_total_time), "ray_download_secs": float(ray_download_time), - "lancedb_write_secs": float(lancedb_write_time), + "lancedb_write_secs": 0.0, "evaluation_secs": float(evaluation_total_time), "total_secs": float(total_time), "evaluation_mode": evaluation_mode, @@ -717,7 +711,7 @@ def main( total_time, ingestion_only_total_time, ray_download_time, - lancedb_write_time, + 0.0, evaluation_total_time, evaluation_metrics, evaluation_label=evaluation_label, diff --git a/nemo_retriever/src/nemo_retriever/graph/__init__.py b/nemo_retriever/src/nemo_retriever/graph/__init__.py index 729b2c001..740f2a798 100644 --- a/nemo_retriever/src/nemo_retriever/graph/__init__.py +++ b/nemo_retriever/src/nemo_retriever/graph/__init__.py @@ -16,6 +16,7 @@ from nemo_retriever.graph.graph_pipeline_registry import GraphPipelineRegistry, default_registry from nemo_retriever.graph.pipeline_graph import Graph, Node from nemo_retriever.graph.store_operator import StoreOperator +from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator __all__ = [ "AbstractExecutor", @@ -32,6 +33,7 @@ "RayDataExecutor", "StoreOperator", "UDFOperator", + "VDBUploadOperator", "default_registry", ] diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py index 8a3193814..96e703f55 100644 --- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py +++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py @@ -15,7 +15,7 @@ from nemo_retriever.audio import ASRActor from nemo_retriever.audio import MediaChunkActor from nemo_retriever.dedup.dedup import dedup_images -from nemo_retriever.graph import Graph, StoreOperator, UDFOperator +from nemo_retriever.graph import Graph, StoreOperator, UDFOperator, VDBUploadOperator from nemo_retriever.graph.content_transforms import ( _CONTENT_COLUMNS, collapse_content_to_page_rows, @@ -243,6 +243,7 @@ def _resolve_execution_inputs( caption_params: Any | None, store_params: Any | None, embed_params: Any | None, + vdb_upload_params: Any | None, stage_order: tuple[str, ...], ) -> tuple[ str, @@ -256,6 +257,7 @@ def _resolve_execution_inputs( Any | None, Any | None, Any | None, + Any | None, tuple[str, ...], ]: """Resolve legacy builder args or a shared execution plan into one input tuple.""" @@ -273,10 +275,12 @@ def _resolve_execution_inputs( caption_params, store_params, embed_params, + vdb_upload_params, stage_order, ) stage_map = {stage.name: stage.params for stage in execution_plan.stages} + sink_map = {sink.name: sink.params for sink in execution_plan.sinks} return ( execution_plan.extraction_mode, 
execution_plan.extract_params, @@ -289,7 +293,8 @@ def _resolve_execution_inputs( stage_map.get("caption"), stage_map.get("store"), stage_map.get("embed"), - tuple(stage.name for stage in execution_plan.stages), + sink_map.get("vdb_upload"), + tuple(stage.name for stage in execution_plan.stages) + tuple(sink.name for sink in execution_plan.sinks), ) @@ -315,6 +320,7 @@ def _append_ordered_transform_stages( caption_params: Any | None, store_params: Any | None, embed_params: Any | None, + vdb_upload_params: Any | None, stage_order: tuple[str, ...], supports_dedup: bool, reshape_for_modal_content: bool, @@ -324,7 +330,8 @@ def _append_ordered_transform_stages( pending_stages = [ stage for stage in stage_order - if stage in {"dedup", "split", "caption", "store", "embed"} and (supports_dedup or stage != "dedup") + if stage in {"dedup", "split", "caption", "store", "embed", "vdb_upload"} + and (supports_dedup or stage != "dedup") ] if not pending_stages: if supports_dedup and dedup_params is not None: @@ -337,6 +344,8 @@ def _append_ordered_transform_stages( pending_stages.append("split") if embed_params is not None: pending_stages.append("embed") + if vdb_upload_params is not None: + pending_stages.append("vdb_upload") for stage_name in pending_stages: if stage_name == "store" and store_params is not None: @@ -374,6 +383,8 @@ def _append_ordered_transform_stages( name="ExplodeContentToRows", ) graph = graph >> _BatchEmbedActor(params=embed_params) + elif stage_name == "vdb_upload" and vdb_upload_params is not None: + graph = graph >> VDBUploadOperator(params=vdb_upload_params) return graph @@ -392,6 +403,7 @@ def build_graph( split_params: Any | None = None, caption_params: Any | None = None, store_params: Any | None = None, + vdb_upload_params: Any | None = None, stage_order: tuple[str, ...] = (), ) -> Graph: """Build a batch graph from explicit params or a shared execution plan.""" @@ -408,6 +420,7 @@ def build_graph( caption_params, store_params, embed_params, + vdb_upload_params, stage_order, ) = _resolve_execution_inputs( execution_plan=execution_plan, @@ -422,6 +435,7 @@ def build_graph( caption_params=caption_params, store_params=store_params, embed_params=embed_params, + vdb_upload_params=vdb_upload_params, stage_order=stage_order, ) @@ -552,6 +566,7 @@ def build_graph( caption_params=caption_params, store_params=store_params, embed_params=embed_params, + vdb_upload_params=vdb_upload_params, stage_order=stage_order, supports_dedup=True, reshape_for_modal_content=True, diff --git a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py new file mode 100644 index 000000000..762c5ed50 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py @@ -0,0 +1,216 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Graph operator for streaming VDB uploads during pipeline execution. + +Wraps existing client VDB classes (``nv_ingest_client.util.vdb``) so that +any backend implementing the client :class:`VDB` ABC can be used as a +pipeline sink. 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Any, Dict, List, Sequence + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.graph.cpu_operator import CPUOperator + +logger = logging.getLogger(__name__) + + +def _canonical_to_nvingest(rows: Sequence[Dict[str, Any]]) -> List[List[Dict[str, Any]]]: + """Convert canonical VDB records to NV-Ingest pipeline format. + + The client Milvus implementation expects ``list[list[dict]]`` where each + inner dict has ``document_type`` and a nested ``metadata`` dict with keys + ``embedding``, ``content``, ``content_metadata``, ``source_metadata``. + """ + elements: List[Dict[str, Any]] = [] + for row in rows: + meta_str = row.get("metadata", "{}") + source_str = row.get("source", "{}") + try: + content_metadata = json.loads(meta_str) if isinstance(meta_str, str) else meta_str + except (json.JSONDecodeError, TypeError): + content_metadata = {} + try: + source_metadata = json.loads(source_str) if isinstance(source_str, str) else source_str + except (json.JSONDecodeError, TypeError): + source_metadata = {} + + elements.append( + { + "document_type": "text", + "metadata": { + "content": row.get("text", ""), + "embedding": row.get("vector"), + "content_metadata": content_metadata, + "source_metadata": source_metadata, + }, + } + ) + return [elements] if elements else [] + + +class VDBUploadOperator(AbstractOperator, CPUOperator): + """Write pipeline embeddings to a vector store as data flows through the graph. + + Wraps a client VDB instance (LanceDB, Milvus, etc.) from + ``nv_ingest_client.util.vdb``. ``preprocess`` extracts canonical + records from the DataFrame (backend-agnostic); ``process`` converts + to the target format and writes (backend-specific). + + The DataFrame is passed through unchanged — this is a side-effect + operator. + + **Concurrency**: This operator must run with ``concurrency=1`` in batch + mode. The single actor creates the table on its first write (respecting + ``overwrite``) and appends on subsequent writes. Index creation happens + post-pipeline via the client VDB's ``write_to_index`` called from the + driver. + """ + + def __init__( + self, + *, + params: Any = None, + ) -> None: + super().__init__() + from nemo_retriever.params.models import LanceDbParams, VdbUploadParams + + # Store as self.params so get_constructor_kwargs() can capture it + # for deferred reconstruction on Ray workers. 
+ self.params = params + + if isinstance(params, VdbUploadParams): + self._vdb_params = params + self._lance_params = params.lancedb + elif isinstance(params, LanceDbParams): + self._vdb_params = None + self._lance_params = params + else: + self._vdb_params = VdbUploadParams() + self._lance_params = self._vdb_params.lancedb + + self._backend_name: str = getattr(self._vdb_params, "backend", "lancedb") if self._vdb_params else "lancedb" + self._client_vdb: Any = None + self._table: Any = None + self._pending_records: List[Dict[str, Any]] = [] + + # ------------------------------------------------------------------ + # Client VDB construction + # ------------------------------------------------------------------ + + def _create_client_vdb(self) -> Any: + """Lazily construct the client VDB instance.""" + from nv_ingest_client.util.vdb import get_vdb_op_cls + + if self._backend_name == "lancedb": + LanceDB = get_vdb_op_cls("lancedb") + return LanceDB( + uri=self._lance_params.lancedb_uri, + table_name=self._lance_params.table_name, + overwrite=self._lance_params.overwrite, + index_type=self._lance_params.index_type, + metric=self._lance_params.metric, + num_partitions=self._lance_params.num_partitions, + num_sub_vectors=self._lance_params.num_sub_vectors, + hybrid=self._lance_params.hybrid, + fts_language=self._lance_params.fts_language, + ) + + kwargs = getattr(self._vdb_params, "client_vdb_kwargs", {}) or {} + return get_vdb_op_cls(self._backend_name)(**kwargs) + + # ------------------------------------------------------------------ + # Operator lifecycle + # ------------------------------------------------------------------ + + def preprocess(self, data: Any, **kwargs: Any) -> Any: + """Extract canonical VDB records from the DataFrame (backend-agnostic).""" + import pandas as pd + + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + if not isinstance(data, pd.DataFrame) or data.empty: + self._pending_records = [] + return data + + self._pending_records = build_vdb_records( + data, + embedding_column=self._lance_params.embedding_column, + embedding_key=self._lance_params.embedding_key, + include_text=self._lance_params.include_text, + text_column=self._lance_params.text_column, + ) + return data + + def process(self, data: Any, **kwargs: Any) -> Any: + """Write pending records to the backend (backend-specific).""" + if not self._pending_records: + return data + + if self._client_vdb is None: + self._client_vdb = self._create_client_vdb() + + if self._backend_name == "lancedb": + self._write_lancedb_batch(self._pending_records) + else: + self._write_via_client(self._pending_records) + + logger.debug("VDBUploadOperator: wrote %d records", len(self._pending_records)) + self._pending_records = [] + return data + + def postprocess(self, data: Any, **kwargs: Any) -> Any: + return data + + # ------------------------------------------------------------------ + # LanceDB streaming write path + # ------------------------------------------------------------------ + + def _write_lancedb_batch(self, records: List[Dict[str, Any]]) -> None: + """Stream records to LanceDB via table.add(). + + The client LanceDB class is used for config (uri, table_name, + overwrite) and post-pipeline index creation, but its + ``create_index()`` writes all records at once with no append + support. For streaming writes we call the lancedb library + directly. 
+ """ + if self._table is None: + import lancedb + + from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema + + dim = infer_vector_dim(records) + schema = lancedb_schema(vector_dim=dim) + mode = "overwrite" if self._client_vdb.overwrite else "create" + db = lancedb.connect(uri=self._client_vdb.uri) + self._table = db.create_table( + self._client_vdb.table_name, + schema=schema, + mode=mode, + ) + + self._table.add(records) + + # ------------------------------------------------------------------ + # Non-LanceDB write path (Milvus, OpenSearch, etc.) + # ------------------------------------------------------------------ + + def _write_via_client(self, records: List[Dict[str, Any]]) -> None: + """Convert canonical records to NV-Ingest format and delegate to the client VDB.""" + nvingest_records = _canonical_to_nvingest(records) + if not nvingest_records: + return + + if self._table is None: + # First batch — create collection schema. + self._client_vdb.create_index() + self._table = True # sentinel: schema created + + self._client_vdb.write_to_index(nvingest_records) diff --git a/nemo_retriever/src/nemo_retriever/graph_ingestor.py b/nemo_retriever/src/nemo_retriever/graph_ingestor.py index c9e3434c7..237c05597 100644 --- a/nemo_retriever/src/nemo_retriever/graph_ingestor.py +++ b/nemo_retriever/src/nemo_retriever/graph_ingestor.py @@ -43,6 +43,7 @@ HtmlChunkParams, StoreParams, TextChunkParams, + VdbUploadParams, ) from nemo_retriever.utils.remote_auth import resolve_remote_api_key @@ -146,6 +147,7 @@ def __init__( self._caption_params: Any = None self._dedup_params: Any = None self._store_params: Any = None + self._vdb_upload_params: Any = None # Ordered list of stage names; "extract" is tracked but excluded from # the post-extraction stage_order passed to graph builders. self._stage_order: List[str] = [] @@ -239,6 +241,17 @@ def embed(self, params: Optional[EmbedParams] = None, **kwargs: Any) -> "GraphIn self._record_stage("embed") return self + def vdb_upload(self, params: Optional[VdbUploadParams] = None, **kwargs: Any) -> "GraphIngestor": + """Record a VDB upload sink. + + Writes embeddings to a vector store as data flows through the graph, + eliminating the need to collect all results on the driver. Index + creation happens automatically after the pipeline completes. + """ + self._vdb_upload_params = _coerce(params, kwargs, default_factory=VdbUploadParams) + self._record_stage("vdb_upload") + return self + # ------------------------------------------------------------------ # Execution # ------------------------------------------------------------------ @@ -287,6 +300,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: caption_params=self._caption_params, dedup_params=self._dedup_params, store_params=self._store_params, + vdb_upload_params=self._vdb_upload_params, stage_order=post_extract_order, ) # Derive per-node Ray scheduling config from BatchTuningParams plus @@ -295,6 +309,16 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: derived_overrides = batch_tuning_to_node_overrides( self._extract_params, self._embed_params, cluster_resources=cluster_resources ) + # VDBUploadOperator must run with concurrency=1 to avoid table + # creation races — a single actor creates the table on its first + # write and appends on subsequent batches. A large batch_size + # amortises per-call LanceDB disk I/O across many rows. 
+            if self._vdb_upload_params is not None:
+                from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator
+
+                vdb_overrides = derived_overrides.setdefault(VDBUploadOperator.__name__, {})
+                vdb_overrides["concurrency"] = 1
+                vdb_overrides["batch_size"] = 64
             merged_overrides: Dict[str, Dict[str, Any]] = {}
             for node_name in set(derived_overrides) | set(self._node_overrides):
                 merged_overrides[node_name] = {
@@ -311,6 +335,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any:
             )
             result = executor.ingest(self._documents)
             self._rd_dataset = result
+            self._finalize_vdb()
             return result
         else:
             graph = build_graph(
@@ -325,11 +350,14 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any:
                 caption_params=self._caption_params,
                 dedup_params=self._dedup_params,
                 store_params=self._store_params,
+                vdb_upload_params=self._vdb_upload_params,
                 stage_order=post_extract_order,
             )
             executor = InprocessExecutor(graph, show_progress=self._show_progress)
             self._rd_dataset = None
-            return executor.ingest(self._documents)
+            result = executor.ingest(self._documents)
+            self._finalize_vdb()
+            return result
 
     # ------------------------------------------------------------------
     # Internal helpers
@@ -404,6 +432,65 @@ def get_error_rows(self, dataset: Any = None) -> Any:
     def get_dataset(self) -> Any:
         return self._rd_dataset
 
+    def _finalize_vdb(self) -> None:
+        """Create VDB indices after the pipeline writes all rows.
+
+        Delegates to the client VDB class for index creation. For LanceDB
+        the client's ``write_to_index(None, table=table)`` call builds the
+        ANN (and optionally FTS) index. For other backends index creation
+        is handled during streaming writes.
+        """
+        if self._vdb_upload_params is None:
+            return
+
+        from nemo_retriever.params.models import LanceDbParams, VdbUploadParams
+
+        if isinstance(self._vdb_upload_params, VdbUploadParams):
+            backend_type = self._vdb_upload_params.backend
+            lance_params = self._vdb_upload_params.lancedb
+        elif isinstance(self._vdb_upload_params, LanceDbParams):
+            backend_type = "lancedb"
+            lance_params = self._vdb_upload_params
+        else:
+            return
+
+        if backend_type != "lancedb":
+            return
+
+        if not lance_params.create_index:
+            return
+
+        import lancedb
+
+        from nv_ingest_client.util.vdb import get_vdb_op_cls
+
+        try:
+            db = lancedb.connect(uri=lance_params.lancedb_uri)
+            table = db.open_table(lance_params.table_name)
+        except Exception:
+            return
+
+        ClientLanceDB = get_vdb_op_cls("lancedb")
+        client = ClientLanceDB(
+            uri=lance_params.lancedb_uri,
+            table_name=lance_params.table_name,
+            overwrite=False,
+            index_type=lance_params.index_type,
+            metric=lance_params.metric,
+            num_partitions=lance_params.num_partitions,
+            num_sub_vectors=lance_params.num_sub_vectors,
+            hybrid=lance_params.hybrid,
+            fts_language=lance_params.fts_language,
+        )
+        try:
+            client.write_to_index(None, table=table)
+        except RuntimeError:
+            logger.warning(
+                "Index creation failed (likely too few rows for %d partitions); skipping.",
+                lance_params.num_partitions,
+                exc_info=True,
+            )
+
     def _record_stage(self, name: str) -> None:
         """Append *name* to the stage order list (deduplicated in place)."""
         if name not in self._stage_order:
diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py
index 13d02eef2..77b869320 100644
--- a/nemo_retriever/src/nemo_retriever/params/models.py
+++ b/nemo_retriever/src/nemo_retriever/params/models.py
@@ -271,6 +271,8 @@ def _warn_page_granularity_overrides(self) -> "EmbedParams":
 
 
 class VdbUploadParams(_ParamsModel):
purge_results_after_upload: bool = True lancedb: LanceDbParams = Field(default_factory=LanceDbParams) + backend: str = "lancedb" # "lancedb" | "milvus" | "opensearch" + client_vdb_kwargs: dict[str, Any] = Field(default_factory=dict) class StoreParams(_ParamsModel): diff --git a/nemo_retriever/src/nemo_retriever/recall/core.py b/nemo_retriever/src/nemo_retriever/recall/core.py index 9076ad33e..f0e7687bd 100644 --- a/nemo_retriever/src/nemo_retriever/recall/core.py +++ b/nemo_retriever/src/nemo_retriever/recall/core.py @@ -305,10 +305,13 @@ def _hits_to_keys(raw_hits: List[List[Dict[str, Any]]]) -> List[List[str]]: keys: List[str] = [] for h in hits: page_number = h["page_number"] - source = h["source"] + raw_source = h["source"] + # source may be a bare path string or a JSON object {"source_id": "..."}. + source_map = _parse_mapping(raw_source) + source = source_map.get("source_id", raw_source) if source_map else raw_source # Prefer explicit `pdf_page` column; fall back to derived form. if page_number is not None and source: - filename = Path(source).stem + filename = Path(str(source)).stem keys.append(f"{filename}_{str(page_number)}") else: logger.warning( diff --git a/nemo_retriever/src/nemo_retriever/text_embed/processor.py b/nemo_retriever/src/nemo_retriever/text_embed/processor.py index 81dd4b8a6..662829112 100644 --- a/nemo_retriever/src/nemo_retriever/text_embed/processor.py +++ b/nemo_retriever/src/nemo_retriever/text_embed/processor.py @@ -14,7 +14,8 @@ from nv_ingest_api.internal.transform.embed_text import transform_create_text_embeddings_internal from nemo_retriever.io.dataframe import validate_primitives_dataframe -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_embeddings_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_embeddings_to_lancedb logger = logging.getLogger(__name__) @@ -25,7 +26,7 @@ def embed_text_from_primitives_df( *, transform_config: TextEmbeddingSchema, task_config: Optional[Dict[str, Any]] = None, - lancedb: Optional[LanceDBConfig] = None, + lancedb: Optional[LanceDbParams] = None, trace_info: Optional[Dict[str, Any]] = None, ) -> Tuple[pd.DataFrame, Dict[str, Any]]: """Generate embeddings for supported content types and write to metadata.""" diff --git a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py index 1c05e4e8f..de519b5e8 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py @@ -4,15 +4,16 @@ from .__main__ import app from .lancedb_store import ( - LanceDBConfig, create_lancedb_index, write_embeddings_to_lancedb, write_text_embeddings_dir_to_lancedb, ) +from .vdb_records import build_vdb_records, build_vdb_records_from_dicts __all__ = [ "app", - "LanceDBConfig", + "build_vdb_records", + "build_vdb_records_from_dicts", "create_lancedb_index", "write_embeddings_to_lancedb", "write_text_embeddings_dir_to_lancedb", diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index 39165a7c5..dec989ed5 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py @@ -6,43 +6,18 @@ import json import logging -from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Sequence, 
Tuple # noqa: F401 +from typing import Any, Dict, List, Optional from datetime import timedelta from nv_ingest_client.util.vdb.lancedb import LanceDB -from nemo_retriever.vector_store.lancedb_utils import lancedb_schema +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts import pandas as pd -import lancedb logger = logging.getLogger(__name__) -@dataclass(frozen=True) -class LanceDBConfig: - """ - Minimal config for writing embeddings into LanceDB. - - This module is intentionally lightweight: it can be used by the text-embedding - stage (`nemo_retriever.text_embed.stage`) and by the vector-store CLI (`nemo_retriever.vector_store.stage`). - """ - - uri: str = "lancedb" - table_name: str = "nv-ingest" - overwrite: bool = True - - # Optional index creation (recommended for recall/search runs). - create_index: bool = True - index_type: str = "IVF_HNSW_SQ" - metric: str = "l2" - num_partitions: int = 16 - num_sub_vectors: int = 256 - - hybrid: bool = False - fts_language: str = "English" - - def _read_text_embeddings_json_df(path: Path) -> pd.DataFrame: """ Read a `*.text_embeddings.json` file emitted by `nemo_retriever.text_embed.stage`. @@ -86,110 +61,7 @@ def _iter_text_embeddings_json_files(input_dir: Path, *, recursive: bool) -> Lis return sorted([p for p in files if p.is_file()]) -def _safe_str(x: Any) -> str: - return "" if x is None else str(x) - - -def _extract_source_path_and_id(meta: Dict[str, Any]) -> Tuple[str, str]: - """ - Extract a stable source path/id from metadata. - - Prefers: - - metadata.source_metadata.source_id - - metadata.source_metadata.source_name - - metadata.custom_content.path - """ - source = meta.get("source_metadata") if isinstance(meta.get("source_metadata"), dict) else {} - source_id = source.get("source_id") or "" - source_name = source.get("source_name") or "" - - custom = meta.get("custom_content") if isinstance(meta.get("custom_content"), dict) else {} - custom_path = custom.get("path") or custom.get("input_pdf") or custom.get("pdf_path") or "" - - path = _safe_str(custom_path or source_id or source_name) - sid = _safe_str(source_id or path or source_name) - return path, sid - - -def _extract_page_number(meta: Dict[str, Any]) -> int: - cm = meta.get("content_metadata") if isinstance(meta.get("content_metadata"), dict) else {} - page = cm.get("hierarchy", {}).get("page", -1) - try: - return int(page) - except Exception: - return -1 - - -def _build_lancedb_rows_from_df(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]: - """ - Transform an embeddings-enriched primitives DataFrame into LanceDB rows. - - Rows include: - - vector (embedding) - - pdf_basename - - page_number - - pdf_page (basename_page) - - source_id - - path - """ - out: List[Dict[str, Any]] = [] - - for row in rows: - meta = row.get("metadata") - if not isinstance(meta, dict): - continue - - embedding = meta.get("embedding") - if embedding is None: - continue - - # Normalize embedding to list[float] - if not isinstance(embedding, list): - try: - embedding = list(embedding) # type: ignore[arg-type] - except Exception: - continue - meta.pop("embedding", None) # Remove embedding from metadata to save space in LanceDB. 
- # path, source_id = _extract_source_path_and_id(meta) - path = row.get("path", "") - source_id = meta.get("source_path", path) - # page_number = _extract_page_number(meta) - page_number = row.get("page_number", -1) - p = Path(path) if path else None - filename = p.name if p is not None else "" - pdf_basename = p.stem if p is not None else "" - pdf_page = f"{pdf_basename}_{page_number}" if (pdf_basename and page_number >= 0) else "" - - if page_number == -1: - logger.debug("Unable to determine page number for %s", path) - - out.append( - { - "vector": embedding, - "pdf_page": pdf_page, - "filename": filename, - "pdf_basename": pdf_basename, - "page_number": int(page_number), - "source": source_id, - "source_id": source_id, - "path": path, - "text": row.get("text", ""), - "metadata": str(meta), - } - ) - - return out - - -def _infer_vector_dim(rows: Sequence[Dict[str, Any]]) -> int: - for r in rows: - v = r.get("vector") - if isinstance(v, list) and v: - return int(len(v)) - return 0 - - -def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = "text") -> None: +def create_lancedb_index(table: Any, *, cfg: LanceDbParams, text_column: str = "text") -> None: """Create vector (IVF_HNSW_SQ) and optionally FTS indices on a LanceDB table.""" try: table.create_index( @@ -216,48 +88,32 @@ def create_lancedb_index(table: Any, *, cfg: LanceDBConfig, text_column: str = " table.wait_for_index([index_stub.name], timeout=timedelta(seconds=600)) -def _write_rows_to_lancedb(rows: Sequence[Dict[str, Any]], *, cfg: LanceDBConfig) -> None: - if not rows: - logger.warning("No embeddings rows provided; nothing to write to LanceDB.") - return - - dim = _infer_vector_dim(rows) - if dim <= 0: - raise ValueError("Failed to infer embedding dimension from rows.") +def write_embeddings_to_lancedb(df_with_embeddings: pd.DataFrame, *, cfg: LanceDbParams) -> None: + """ + Write embeddings found in *df_with_embeddings* to LanceDB. - try: - import lancedb # type: ignore - except Exception as e: - raise RuntimeError( - "LanceDB write requested but dependencies are missing. " - "Install `lancedb` and `pyarrow` in this environment." - ) from e + This is used programmatically by ``nemo_retriever.text_embed.stage``. + """ + import lancedb - db = lancedb.connect(uri=cfg.uri) + from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema + records = build_vdb_records(df_with_embeddings) + if not records: + return + dim = infer_vector_dim(records) schema = lancedb_schema(vector_dim=dim) - - mode = "overwrite" if cfg.overwrite else "append" - table = db.create_table(cfg.table_name, data=list(rows), schema=schema, mode=mode) - + mode = "overwrite" if cfg.overwrite else "create" + db = lancedb.connect(uri=cfg.lancedb_uri) + table = db.create_table(cfg.table_name, data=records, schema=schema, mode=mode) if cfg.create_index: create_lancedb_index(table, cfg=cfg) -def write_embeddings_to_lancedb(df_with_embeddings: pd.DataFrame, *, cfg: LanceDBConfig) -> None: - """ - Write embeddings found in `df_with_embeddings.metadata.embedding` to LanceDB. - - This is used programmatically by `nemo_retriever.text_embed.stage.embed_text_from_primitives_df(...)`. 
- """ - rows = _build_lancedb_rows_from_df(df_with_embeddings) - _write_rows_to_lancedb(rows, cfg=cfg) - - def write_text_embeddings_dir_to_lancedb( input_dir: Path, *, - cfg: LanceDBConfig, + cfg: LanceDbParams, recursive: bool = False, limit: Optional[int] = None, ) -> Dict[str, Any]: @@ -273,7 +129,7 @@ def write_text_embeddings_dir_to_lancedb( skipped = 0 failed = 0 - lancedb = LanceDB(uri=cfg.uri, table_name=cfg.table_name, overwrite=cfg.overwrite) + lancedb_client = LanceDB(uri=cfg.lancedb_uri, table_name=cfg.table_name, overwrite=cfg.overwrite) results = [] @@ -290,10 +146,10 @@ def write_text_embeddings_dir_to_lancedb( "processed": 0, "skipped": 0, "failed": 0, - "lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, + "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } - lancedb.run(results) + lancedb_client.run(results) return { "input_dir": str(input_dir), @@ -301,31 +157,50 @@ def write_text_embeddings_dir_to_lancedb( "processed": processed, "skipped": skipped, "failed": failed, - # "rows_written": len(all_rows), - "lancedb": {"uri": cfg.uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, + "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } def handle_lancedb( - rows: Path, + rows: Any, uri: str, table_name: str, hybrid: bool = False, mode: str = "overwrite", -) -> Dict[str, Any]: +) -> None: + """Write pipeline results to LanceDB. + + Accepts *rows* as a ``pd.DataFrame`` or ``list[dict]`` (e.g. from + ``take_all()``). Converts to canonical VDB records, writes to + LanceDB, and creates search indices. """ - Handle LanceDB writing for a batch pipeline run. + import lancedb - This is used by `nemo_retriever.examples.batch_pipeline.run(...)` after the embedding stage. + from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema - Reads `*.text_embeddings.json` files from `input_dir`, extracts embeddings, and uploads to LanceDB. + if isinstance(rows, pd.DataFrame): + records = build_vdb_records(rows) + else: + records = build_vdb_records_from_dicts(rows) + + if not records: + return + + cfg = LanceDbParams( + lancedb_uri=uri, + table_name=table_name, + hybrid=hybrid, + overwrite=(mode == "overwrite"), ) - """ - lancedb_config = LanceDBConfig( - uri=uri, table_name=table_name, hybrid=hybrid - ) # Use the same LanceDB config for writing and recall. - db = lancedb.connect(uri=lancedb_config.uri) - cleaned_rows = _build_lancedb_rows_from_df(rows) - _write_rows_to_lancedb(cleaned_rows, cfg=lancedb_config) - table = db.open_table(lancedb_config.table_name) # Ensure table is open and metadata is updated before proceeding. 
- create_lancedb_index(table, cfg=lancedb_config) + dim = infer_vector_dim(records) + schema = lancedb_schema(vector_dim=dim) + db = lancedb.connect(uri=uri) + table = db.create_table(table_name, data=records, schema=schema, mode=mode) + try: + create_lancedb_index(table, cfg=cfg) + except RuntimeError: + logger.warning( + "Index creation failed (likely too few rows for %d partitions); skipping.", + cfg.num_partitions, + exc_info=True, + ) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py index 11e117ac6..72db70ab4 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py @@ -11,6 +11,23 @@ from typing import Any, Dict, List, Optional, Tuple +def _ensure_dict(value: Any) -> Optional[Dict[str, Any]]: + """Coerce *value* to a dict, parsing JSON strings if needed. + + Arrow serialization between Ray actors can convert dict columns to + strings. This helper lets downstream code handle both forms. + """ + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + return None + return None + + def extract_embedding_from_row( row: Any, *, @@ -18,14 +35,14 @@ def extract_embedding_from_row( embedding_key: str = "embedding", ) -> Optional[List[float]]: """Extract an embedding vector from a row.""" - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): + meta = _ensure_dict(getattr(row, "metadata", None)) + if meta is not None: emb = meta.get("embedding") if isinstance(emb, list) and emb: return emb # type: ignore[return-value] - payload = getattr(row, embedding_column, None) - if isinstance(payload, dict): + payload = _ensure_dict(getattr(row, embedding_column, None)) + if payload is not None: emb = payload.get(embedding_key) if isinstance(emb, list) and emb: return emb # type: ignore[return-value] @@ -48,8 +65,8 @@ def extract_source_path_and_page(row: Any) -> Tuple[str, int]: except Exception: pass - meta = getattr(row, "metadata", None) - if isinstance(meta, dict): + meta = _ensure_dict(getattr(row, "metadata", None)) + if meta is not None: source_path = meta.get("source_path") if isinstance(source_path, str) and source_path.strip(): path = source_path.strip() @@ -116,8 +133,8 @@ def build_lancedb_row( metadata_obj["pdf_page"] = pdf_page metadata_obj.update(_build_detection_metadata(row)) - orig_meta = getattr(row, "metadata", None) - if isinstance(orig_meta, dict): + orig_meta = _ensure_dict(getattr(row, "metadata", None)) + if orig_meta is not None: for key in ("chunk_index", "chunk_count"): if key in orig_meta: metadata_obj[key] = orig_meta[key] diff --git a/nemo_retriever/src/nemo_retriever/vector_store/stage.py b/nemo_retriever/src/nemo_retriever/vector_store/stage.py index ba3f994a1..7c59d718b 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/stage.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/stage.py @@ -10,7 +10,8 @@ import typer from rich.console import Console -from nemo_retriever.vector_store.lancedb_store import LanceDBConfig, write_text_embeddings_dir_to_lancedb +from nemo_retriever.params.models import LanceDbParams +from nemo_retriever.vector_store.lancedb_store import write_text_embeddings_dir_to_lancedb console = Console() app = typer.Typer(help="Vector store stage: upload stage5 embeddings to a vector DB (LanceDB).") 
@@ -54,8 +55,8 @@ def run( - `page_number`: page number from `metadata.content_metadata.page_number` - `path` / `source_id`: source identifiers """ - cfg = LanceDBConfig( - uri=str(lancedb_uri), + cfg = LanceDbParams( + lancedb_uri=str(lancedb_uri), table_name=str(table_name), overwrite=bool(overwrite), create_index=bool(create_index), @@ -73,7 +74,7 @@ def run( ) console.print( f"[green]Done[/green] files={info['n_files']} processed={info['processed']} skipped={info['skipped']} " - f"failed={info['failed']} lancedb_uri={cfg.uri} table={cfg.table_name}" + f"failed={info['failed']} lancedb_uri={cfg.lancedb_uri} table={cfg.table_name}" ) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py new file mode 100644 index 000000000..668e2a89a --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Canonical VDB record builder. + +Converts a pandas DataFrame (the graph pipeline's output format) into a list +of backend-neutral VDB record dicts. Every VDB backend in ``nemo_retriever`` +consumes this record format — it is the single source of truth for the +DataFrame → VDB record contract. + +Canonical record schema (matches ``retriever.py`` query expectations):: + + vector : list[float] # embedding + text : str # content + metadata : str # JSON string (round-trips via json.loads) + source : str # JSON string {"source_id": "..."} + page_number : int + pdf_page : str # "basename_pagenum" + pdf_basename : str + filename : str + source_id : str + path : str +""" + +from __future__ import annotations + +from typing import Any, Dict, List + +import pandas as pd + +from nemo_retriever.vector_store.lancedb_utils import build_lancedb_row + + +def build_vdb_records( + df: pd.DataFrame, + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Convert a post-embed DataFrame into canonical VDB records. + + Rows without a valid embedding are silently skipped. + """ + rows: List[Dict[str, Any]] = [] + for row in df.itertuples(index=False): + row_out = build_lancedb_row( + row, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) + if row_out is not None: + rows.append(row_out) + return rows + + +def build_vdb_records_from_dicts( + records: List[Dict[str, Any]], + *, + embedding_column: str = "text_embeddings_1b_v2", + embedding_key: str = "embedding", + text_column: str = "text", + include_text: bool = True, +) -> List[Dict[str, Any]]: + """Convert a list of dicts (e.g. from ``take_all()``) into canonical VDB records.""" + if not records: + return [] + df = pd.DataFrame(records) + return build_vdb_records( + df, + embedding_column=embedding_column, + embedding_key=embedding_key, + text_column=text_column, + include_text=include_text, + ) diff --git a/nemo_retriever/tests/test_vdb_record_contract.py b/nemo_retriever/tests/test_vdb_record_contract.py new file mode 100644 index 000000000..031757cb5 --- /dev/null +++ b/nemo_retriever/tests/test_vdb_record_contract.py @@ -0,0 +1,352 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the canonical VDB record contract and LanceDB store helpers. + +These tests validate: + - build_vdb_records produces the correct canonical record format + - build_vdb_records_from_dicts (transitional list[dict] path) matches + - handle_lancedb round-trips data correctly (regression) + - _ensure_dict handles Arrow serialization robustness +""" + +from __future__ import annotations + +import copy +import json + +import pandas as pd +import pytest + + +# --------------------------------------------------------------------------- +# Shared fixture +# --------------------------------------------------------------------------- + + +def _make_sample_dataframe() -> pd.DataFrame: + """Build a minimal DataFrame matching the graph pipeline's post-embed output.""" + embedding = [0.1, 0.2, 0.3, 0.4] + metadata = { + "embedding": embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + } + return pd.DataFrame( + [ + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": "Hello world", + "path": "/data/test.pdf", + "page_number": 0, + "page_elements_v3_num_detections": 5, + "page_elements_v3_counts_by_label": {"text": 3, "table": 2}, + } + ] + ) + + +# --------------------------------------------------------------------------- +# Canonical record builder tests +# --------------------------------------------------------------------------- + + +class TestBuildVdbRecords: + def test_produces_all_required_fields(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert len(rows) == 1 + row = rows[0] + for field in ( + "vector", + "text", + "metadata", + "source", + "page_number", + "pdf_page", + "pdf_basename", + "source_id", + "path", + "filename", + ): + assert field in row, f"Missing required field: {field}" + + def test_metadata_is_valid_json(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert isinstance(meta, dict) + assert "page_number" in meta + + def test_metadata_includes_detection_counts(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + meta = json.loads(rows[0]["metadata"]) + assert meta["page_elements_v3_num_detections"] == 5 + assert meta["page_elements_v3_counts_by_label"] == {"text": 3, "table": 2} + + def test_source_is_json_object(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + source = json.loads(rows[0]["source"]) + assert isinstance(source, dict) + assert "source_id" in source + + def test_does_not_mutate_input(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + original_meta = copy.deepcopy(df.iloc[0]["metadata"]) + + build_vdb_records(df) + + assert df.iloc[0]["metadata"] == original_meta + + def test_vector_is_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + + assert rows[0]["vector"] == [0.1, 0.2, 0.3, 0.4] + + def test_skips_rows_without_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = 
pd.DataFrame( + [ + { + "metadata": {"source_path": "/data/test.pdf"}, + "text": "No embedding here", + "path": "/data/test.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 0 + + def test_empty_dataframe(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = pd.DataFrame() + rows = build_vdb_records(df) + assert rows == [] + + def test_text_content(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df) + assert rows[0]["text"] == "Hello world" + + def test_include_text_false(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + df = _make_sample_dataframe() + rows = build_vdb_records(df, include_text=False) + assert rows[0]["text"] == "" + + +# --------------------------------------------------------------------------- +# Transitional list[dict] builder tests +# --------------------------------------------------------------------------- + + +class TestBuildVdbRecordsFromDicts: + def test_matches_dataframe_path(self): + """list[dict] path should produce identical output to DataFrame path.""" + from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts + + df = _make_sample_dataframe() + records = df.to_dict("records") + + from_df = build_vdb_records(df) + from_dicts = build_vdb_records_from_dicts(records) + + assert len(from_df) == len(from_dicts) + assert from_df[0]["vector"] == from_dicts[0]["vector"] + assert from_df[0]["text"] == from_dicts[0]["text"] + assert from_df[0]["path"] == from_dicts[0]["path"] + assert from_df[0]["page_number"] == from_dicts[0]["page_number"] + + def test_empty_list(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records_from_dicts + + assert build_vdb_records_from_dicts([]) == [] + + +# --------------------------------------------------------------------------- +# Regression: handle_lancedb uses canonical builder +# --------------------------------------------------------------------------- + + +class TestHandleLancedbRegression: + def test_handle_lancedb_writes_valid_json_metadata(self, tmp_path): + """After refactoring, handle_lancedb should produce valid JSON metadata.""" + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert len(result) == 1 + meta_str = result.iloc[0]["metadata"] + meta = json.loads(meta_str) + assert isinstance(meta, dict) + assert "page_number" in meta + + def test_handle_lancedb_accepts_dataframe(self, tmp_path): + """handle_lancedb should now accept a DataFrame directly.""" + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + + uri = str(tmp_path / "test_db") + handle_lancedb(df, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + assert table.count_rows() == 1 + + def test_handle_lancedb_round_trip_preserves_text(self, tmp_path): + """Text content should survive the write→read round-trip.""" + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = 
str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert result.iloc[0]["text"] == "Hello world" + + def test_handle_lancedb_round_trip_preserves_path(self, tmp_path): + from nemo_retriever.vector_store.lancedb_store import handle_lancedb + import lancedb + + df = _make_sample_dataframe() + rows = df.to_dict("records") + + uri = str(tmp_path / "test_db") + handle_lancedb(rows, uri, "test_table", mode="overwrite") + + db = lancedb.connect(uri) + table = db.open_table("test_table") + result = table.to_pandas() + + assert result.iloc[0]["path"] == "/data/test.pdf" + assert result.iloc[0]["page_number"] == 0 + + +# --------------------------------------------------------------------------- +# _ensure_dict — Arrow serialization robustness +# --------------------------------------------------------------------------- + +class TestEnsureDict: + """Tests for the _ensure_dict helper that handles Arrow-serialized dict columns.""" + + def test_dict_passthrough(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + d = {"a": 1} + assert _ensure_dict(d) is d + + def test_json_string_parsed(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict('{"a": 1}') == {"a": 1} + + def test_none_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict(None) is None + + def test_non_dict_json_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict("[1, 2, 3]") is None + + def test_malformed_json_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict("not json{") is None + + def test_integer_returns_none(self): + from nemo_retriever.vector_store.lancedb_utils import _ensure_dict + + assert _ensure_dict(42) is None + + +class TestBuildVdbRecordsArrowCompat: + """build_vdb_records handles string-encoded dict columns from Arrow.""" + + def test_string_metadata_with_embedding(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + embedding = [0.1, 0.2, 0.3, 0.4] + metadata = json.dumps({ + "embedding": embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + }) + df = pd.DataFrame( + [ + { + "metadata": metadata, + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "hello world", + "path": "/data/test.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 1 + assert rows[0]["vector"] == embedding + + def test_string_embed_column_only(self): + from nemo_retriever.vector_store.vdb_records import build_vdb_records + + embedding = [1.0, 2.0] + df = pd.DataFrame( + [ + { + "metadata": "{}", + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "test", + "path": "/x.pdf", + "page_number": 0, + } + ] + ) + rows = build_vdb_records(df) + assert len(rows) == 1 + assert rows[0]["vector"] == embedding diff --git a/nemo_retriever/tests/test_vdb_upload_operator.py b/nemo_retriever/tests/test_vdb_upload_operator.py new file mode 100644 index 000000000..1adbd577f --- /dev/null +++ b/nemo_retriever/tests/test_vdb_upload_operator.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for VDBUploadOperator — the graph-based VDB write path.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + +from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator +from nemo_retriever.params.models import LanceDbParams, VdbUploadParams + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +def _make_embedded_df(n: int = 3, dim: int = 4) -> pd.DataFrame: + """Build a minimal post-embed DataFrame with *n* rows.""" + rows = [] + for i in range(n): + embedding = [float(i + j) for j in range(dim)] + metadata = { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": f"content of page {i}", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + +@pytest.fixture() +def lance_params(tmp_path): + return LanceDbParams(lancedb_uri=str(tmp_path / "test_lancedb"), table_name="test_table") + + +@pytest.fixture() +def vdb_params(lance_params): + return VdbUploadParams(lancedb=lance_params) + + +# --------------------------------------------------------------------------- +# LanceDB write path tests +# --------------------------------------------------------------------------- + + +class TestVDBUploadOperator: + def test_writes_records_to_lancedb(self, vdb_params, lance_params): + """Operator writes canonical VDB records during process().""" + df = _make_embedded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + + assert isinstance(result, pd.DataFrame) + assert len(result) == 3 + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 3 + + def test_multiple_batches_accumulate(self, vdb_params, lance_params): + """Multiple process() calls append rows, not overwrite.""" + op = VDBUploadOperator(params=vdb_params) + op.run(_make_embedded_df(2)) + op.run(_make_embedded_df(3)) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 5 + + def test_empty_dataframe_is_noop(self, vdb_params): + """Empty DataFrame doesn't create a table or crash.""" + op = VDBUploadOperator(params=vdb_params) + result = op.run(pd.DataFrame()) + assert isinstance(result, pd.DataFrame) + assert len(result) == 0 + assert op._client_vdb is None + + def test_no_embeddings_is_noop(self, vdb_params): + """DataFrame without embedding columns produces no VDB records.""" + df = pd.DataFrame({"text": ["hello"], "path": ["/test.txt"]}) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + assert len(result) == 1 + assert op._client_vdb is None + + def test_accepts_lance_params_directly(self, lance_params): + """Operator accepts LanceDbParams in addition to VdbUploadParams.""" + df = _make_embedded_df(1) + op = VDBUploadOperator(params=lance_params) + op.run(df) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 1 + + def test_default_params(self): + """Operator works with no params (uses 
defaults).""" + op = VDBUploadOperator() + assert op._lance_params is not None + + def test_preprocess_extracts_records(self, vdb_params): + """preprocess populates _pending_records from the DataFrame.""" + df = _make_embedded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.preprocess(df) + + assert result is df + assert len(op._pending_records) == 3 + assert all("vector" in r for r in op._pending_records) + + +# --------------------------------------------------------------------------- +# Arrow serialization compat +# --------------------------------------------------------------------------- + + +class TestArrowSerializationCompat: + """Regression tests for Arrow-serialized dict columns (Ray Data pipeline).""" + + def _make_string_encoded_df(self, n: int = 3, dim: int = 4) -> pd.DataFrame: + """Build a DataFrame where dict columns are JSON strings, simulating Arrow round-trip.""" + rows = [] + for i in range(n): + embedding = [float(i + j) for j in range(dim)] + metadata = json.dumps({ + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + }) + embed_payload = json.dumps({"embedding": embedding, "info_msg": None}) + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": embed_payload, + "text": f"content of page {i}", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + def test_writes_records_with_string_metadata(self, vdb_params, lance_params): + """Operator handles Arrow-serialized string columns (not dicts).""" + df = self._make_string_encoded_df(3) + op = VDBUploadOperator(params=vdb_params) + result = op.run(df) + + assert len(result) == 3 + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 3 + + def test_writes_records_with_string_embed_column_only(self, vdb_params, lance_params): + """Embedding extracted from string-encoded embedding column.""" + embedding = [1.0, 2.0, 3.0, 4.0] + df = pd.DataFrame( + [ + { + "metadata": {"source_path": "/test.pdf"}, + "text_embeddings_1b_v2": json.dumps({"embedding": embedding}), + "text": "hello", + "path": "/test.pdf", + "page_number": 0, + } + ] + ) + op = VDBUploadOperator(params=vdb_params) + op.run(df) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 1 + + +# --------------------------------------------------------------------------- +# Milvus write path (mocked client) +# --------------------------------------------------------------------------- + + +class TestVDBUploadMilvus: + """VDBUploadOperator wrapping a mocked Milvus client VDB.""" + + def _make_milvus_params(self): + return VdbUploadParams(backend="milvus", client_vdb_kwargs={"collection_name": "test"}) + + @patch("nemo_retriever.graph.vdb_upload_operator.get_vdb_op_cls", create=True) + def test_delegates_to_client_write_to_index(self, mock_get_cls): + mock_client = MagicMock() + mock_get_cls.return_value = lambda **kwargs: mock_client + + params = self._make_milvus_params() + op = VDBUploadOperator(params=params) + df = _make_embedded_df(3) + + # Patch get_vdb_op_cls at the right location + with patch("nv_ingest_client.util.vdb.get_vdb_op_cls", return_value=lambda **kw: mock_client): + op.run(df) + + mock_client.create_index.assert_called_once() + mock_client.write_to_index.assert_called_once() + call_args = 
mock_client.write_to_index.call_args[0][0] + assert isinstance(call_args, list) + assert isinstance(call_args[0], list) + assert call_args[0][0]["document_type"] == "text" + assert "embedding" in call_args[0][0]["metadata"] + + @patch("nemo_retriever.graph.vdb_upload_operator.get_vdb_op_cls", create=True) + def test_multiple_batches_call_write_per_batch(self, mock_get_cls): + mock_client = MagicMock() + mock_get_cls.return_value = lambda **kwargs: mock_client + + params = self._make_milvus_params() + op = VDBUploadOperator(params=params) + + with patch("nv_ingest_client.util.vdb.get_vdb_op_cls", return_value=lambda **kw: mock_client): + op.run(_make_embedded_df(2)) + op.run(_make_embedded_df(3)) + + mock_client.create_index.assert_called_once() + assert mock_client.write_to_index.call_count == 2 + + +# --------------------------------------------------------------------------- +# Finalization +# --------------------------------------------------------------------------- + + +class TestVDBFinalization: + def test_lancedb_table_accessible_after_writes(self, vdb_params, lance_params): + """After operator writes, the table can be opened from disk.""" + df = _make_embedded_df(5) + op = VDBUploadOperator(params=vdb_params) + op.run(df) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 5 From d3de5602fde8fa5ad63434cb838c30964f4a36bf Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Tue, 14 Apr 2026 18:27:42 +0000 Subject: [PATCH 2/6] vector_store: add Milvus integration tests and fix client kwargs dispatch - Fix _write_via_client to use get_connection_params/get_write_params matching the client Milvus.run() dispatch pattern - Add integration_test_milvus_vdb.py (writes + search verification) - Add integration_test_milvus_recall.py (full jp20 pipeline + recall) - Milvus recall: @1=0.6435 @5=0.8783 @10=0.9217 (matches LanceDB) Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Jacob Ioffe --- .../graph/vdb_upload_operator.py | 15 +- .../vector_store/lancedb_store.py | 44 --- .../tests/integration_test_milvus_recall.py | 258 ++++++++++++++++++ .../tests/integration_test_milvus_vdb.py | 119 ++++++++ .../tests/test_vdb_record_contract.py | 77 ------ .../tests/test_vdb_upload_operator.py | 12 +- 6 files changed, 397 insertions(+), 128 deletions(-) create mode 100644 nemo_retriever/tests/integration_test_milvus_recall.py create mode 100644 nemo_retriever/tests/integration_test_milvus_vdb.py diff --git a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py index 762c5ed50..c10572c00 100644 --- a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py @@ -203,14 +203,21 @@ def _write_lancedb_batch(self, records: List[Dict[str, Any]]) -> None: # ------------------------------------------------------------------ def _write_via_client(self, records: List[Dict[str, Any]]) -> None: - """Convert canonical records to NV-Ingest format and delegate to the client VDB.""" + """Convert canonical records to NV-Ingest format and delegate to the client VDB. + + Client VDB classes split their config into connection params + (for ``create_index``) and write params (for ``write_to_index``) + via ``get_connection_params()`` / ``get_write_params()``. We + mirror the same dispatch that ``Milvus.run()`` uses. 
+ """ nvingest_records = _canonical_to_nvingest(records) if not nvingest_records: return if self._table is None: - # First batch — create collection schema. - self._client_vdb.create_index() + collection_name, create_params = self._client_vdb.get_connection_params() + self._client_vdb.create_index(collection_name=collection_name, **create_params) self._table = True # sentinel: schema created - self._client_vdb.write_to_index(nvingest_records) + _, write_params = self._client_vdb.get_write_params() + self._client_vdb.write_to_index(nvingest_records, **write_params) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index dec989ed5..e25ecc349 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py @@ -160,47 +160,3 @@ def write_text_embeddings_dir_to_lancedb( "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } - -def handle_lancedb( - rows: Any, - uri: str, - table_name: str, - hybrid: bool = False, - mode: str = "overwrite", -) -> None: - """Write pipeline results to LanceDB. - - Accepts *rows* as a ``pd.DataFrame`` or ``list[dict]`` (e.g. from - ``take_all()``). Converts to canonical VDB records, writes to - LanceDB, and creates search indices. - """ - import lancedb - - from nemo_retriever.vector_store.lancedb_utils import infer_vector_dim, lancedb_schema - - if isinstance(rows, pd.DataFrame): - records = build_vdb_records(rows) - else: - records = build_vdb_records_from_dicts(rows) - - if not records: - return - - cfg = LanceDbParams( - lancedb_uri=uri, - table_name=table_name, - hybrid=hybrid, - overwrite=(mode == "overwrite"), - ) - dim = infer_vector_dim(records) - schema = lancedb_schema(vector_dim=dim) - db = lancedb.connect(uri=uri) - table = db.create_table(table_name, data=records, schema=schema, mode=mode) - try: - create_lancedb_index(table, cfg=cfg) - except RuntimeError: - logger.warning( - "Index creation failed (likely too few rows for %d partitions); skipping.", - cfg.num_partitions, - exc_info=True, - ) diff --git a/nemo_retriever/tests/integration_test_milvus_recall.py b/nemo_retriever/tests/integration_test_milvus_recall.py new file mode 100644 index 000000000..16064f195 --- /dev/null +++ b/nemo_retriever/tests/integration_test_milvus_recall.py @@ -0,0 +1,258 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""End-to-end Milvus recall test: ingest jp20 → embed queries → search → compute recall. 
+ +Requires: + - Running Milvus instance + - GPU for extraction + embedding + - jp20 dataset at /datasets/nv-ingest/jp20 + - Query CSV at data/jp20_query_gt.csv + +Usage: + python tests/integration_test_milvus_recall.py [--milvus-uri URI] +""" + +from __future__ import annotations + +import sys +import time +from pathlib import Path + +MILVUS_URI = "http://172.20.0.4:19530" +COLLECTION_NAME = "jp20_recall_test" +DATASET_DIR = "/datasets/nv-ingest/jp20" +QUERY_CSV = "/raid/jioffe/NeMo-Retriever/data/jp20_query_gt.csv" +EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2" +EMBED_DIM = 2048 +TOP_K = 10 + + +def ingest_jp20_to_milvus(): + """Run the full graph pipeline with Milvus as the VDB backend.""" + from nemo_retriever import create_ingestor + from nemo_retriever.params import EmbedParams, ExtractParams, VdbUploadParams + from nemo_retriever.params.models import LanceDbParams + + print("--- Step 1: Ingesting jp20 into Milvus ---") + t0 = time.perf_counter() + + ingestor = ( + create_ingestor(run_mode="batch") + .files(DATASET_DIR + "/*.pdf") + .extract(extract_text=True, extract_tables=True, extract_charts=True) + .embed( + EmbedParams( + model_name=EMBED_MODEL, + inference_batch_size=32, + ) + ) + .vdb_upload( + VdbUploadParams( + backend="milvus", + client_vdb_kwargs={ + "milvus_uri": MILVUS_URI, + "collection_name": COLLECTION_NAME, + "dense_dim": EMBED_DIM, + "recreate": True, + "gpu_index": False, + "stream": True, + "sparse": False, + }, + ) + ) + ) + ingestor.ingest() + + elapsed = time.perf_counter() - t0 + print(f"[OK] Ingestion complete in {elapsed:.1f}s") + return elapsed + + +def embed_queries(): + """Embed jp20 queries using the local model.""" + import pandas as pd + + from nemo_retriever.model import create_local_embedder + + print("--- Step 2: Embedding queries ---") + + df = pd.read_csv(QUERY_CSV) + queries = df["query"].astype(str).tolist() + gold_keys = [] + for _, row in df.iterrows(): + if "pdf_page" in df.columns: + gold_keys.append(str(row["pdf_page"])) + else: + gold_keys.append(f"{row['pdf']}_{row['page']}") + + print(f" {len(queries)} queries, embedding with {EMBED_MODEL}...") + embedder = create_local_embedder(EMBED_MODEL, device="cuda") + vecs = embedder.embed(["query: " + q for q in queries], batch_size=32) + query_embeddings = vecs.detach().to("cpu").tolist() + print(f"[OK] Embedded {len(query_embeddings)} queries") + + return queries, gold_keys, query_embeddings + + +def search_milvus(query_embeddings): + """Search Milvus collection with pre-computed query embeddings.""" + from pymilvus import MilvusClient + + print("--- Step 3: Searching Milvus ---") + + client = MilvusClient(uri=MILVUS_URI) + t0 = time.perf_counter() + + all_hits = [] + # Batch queries to avoid overwhelming Milvus + batch_size = 20 + for i in range(0, len(query_embeddings), batch_size): + batch = query_embeddings[i : i + batch_size] + results = client.search( + collection_name=COLLECTION_NAME, + data=batch, + limit=TOP_K, + output_fields=["text", "source", "content_metadata"], + ) + all_hits.extend(results) + + elapsed = time.perf_counter() - t0 + print(f"[OK] Searched {len(query_embeddings)} queries in {elapsed:.1f}s ({len(query_embeddings)/elapsed:.1f} QPS)") + + return all_hits + + +def hits_to_keys(all_hits): + """Extract pdf_page keys from Milvus search results.""" + import json + + retrieved_keys = [] + for hits in all_hits: + keys = [] + for h in hits: + entity = h.get("entity", {}) + # Milvus stores content_metadata and source as dicts (not JSON strings) + content_meta = 
entity.get("content_metadata", {})
+            source = entity.get("source", {})
+
+            # Fall back to JSON parsing in case these arrive string-encoded.
+            if isinstance(content_meta, str):
+                try:
+                    content_meta = json.loads(content_meta)
+                except (json.JSONDecodeError, TypeError):
+                    content_meta = {}
+            if isinstance(source, str):
+                try:
+                    source = json.loads(source)
+                except (json.JSONDecodeError, TypeError):
+                    source = {}
+
+            source_id = source.get("source_id", "")
+            page_number = content_meta.get("page_number")
+            if page_number is None:
+                hierarchy = content_meta.get("hierarchy", {})
+                if isinstance(hierarchy, dict):
+                    page_number = hierarchy.get("page")
+
+            if source_id and page_number is not None:
+                filename = Path(str(source_id)).stem
+                keys.append(f"{filename}_{page_number}")
+        retrieved_keys.append(keys)
+
+    return retrieved_keys
+
+
+def compute_recall(gold_keys, retrieved_keys, ks=(1, 5, 10)):
+    """Compute recall@k metrics."""
+    print("--- Step 4: Computing recall ---")
+
+    metrics = {}
+    for k in ks:
+        hits = 0
+        for gold, retrieved in zip(gold_keys, retrieved_keys):
+            top = retrieved[:k]
+            parts = str(gold).rsplit("_", 1)
+            if len(parts) == 2:
+                # A retrieved key of "<doc>_-1" marks a whole-document chunk,
+                # which counts as a hit for any gold page in that document.
+                whole_doc = f"{parts[0]}_-1"
+                if gold in top or whole_doc in top:
+                    hits += 1
+            elif gold in top:
+                hits += 1
+        recall = hits / max(1, len(gold_keys))
+        metrics[f"recall@{k}"] = recall
+        print(f" recall@{k}: {recall:.4f}")
+
+    return metrics
+
+
+def cleanup():
+    from pymilvus import MilvusClient
+
+    client = MilvusClient(uri=MILVUS_URI)
+    if client.has_collection(COLLECTION_NAME):
+        client.drop_collection(COLLECTION_NAME)
+        print(f"[OK] Cleaned up collection {COLLECTION_NAME}")
+
+
+def main():
+    print("=" * 60)
+    print("Milvus End-to-End Recall Test (jp20)")
+    print(f" Milvus: {MILVUS_URI}")
+    print(f" Collection: {COLLECTION_NAME}")
+    print(f" Dataset: {DATASET_DIR}")
+    print(f" Queries: {QUERY_CSV}")
+    print("=" * 60)
+    print()
+
+    try:
+        ingest_secs = ingest_jp20_to_milvus()
+        queries, gold_keys, query_embeddings = embed_queries()
+        all_hits = search_milvus(query_embeddings)
+        retrieved_keys = hits_to_keys(all_hits)
+        metrics = compute_recall(gold_keys, retrieved_keys)
+
+        print()
+        print("=" * 60)
+        print("RESULTS")
+        print(f" Ingestion: {ingest_secs:.1f}s")
+        for k, v in metrics.items():
+            print(f" {k}: {v:.4f}")
+        print()
+
+        # Compare against LanceDB baseline
+        baseline = {"recall@1": 0.6435, "recall@5": 0.8783, "recall@10": 0.9304}
+        print("vs LanceDB baseline:")
+        all_match = True
+        for k, bl in baseline.items():
+            mv = metrics.get(k, 0)
+            delta = mv - bl
+            status = "MATCH" if abs(delta) < 0.01 else ("BETTER" if delta > 0 else "WORSE")
+            print(f" {k}: Milvus={mv:.4f} LanceDB={bl:.4f} ({status})")
+            if status == "WORSE":
+                all_match = False
+
+        print()
+        if all_match:
+            print("=== PASS: Milvus recall matches LanceDB baseline ===")
+        else:
+            print("=== NOTE: Milvus recall differs from LanceDB baseline ===")
+
+        return 0
+
+    except Exception as e:
+        print(f"\n[FAIL] {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+    finally:
+        cleanup()
+
+
+if __name__ == "__main__":
+    if "--milvus-uri" in sys.argv:
+        idx = sys.argv.index("--milvus-uri")
+        MILVUS_URI = sys.argv[idx + 1]
+    sys.exit(main())
diff --git a/nemo_retriever/tests/integration_test_milvus_vdb.py b/nemo_retriever/tests/integration_test_milvus_vdb.py
new file mode 100644
index 000000000..31e68e98d
--- /dev/null
+++ b/nemo_retriever/tests/integration_test_milvus_vdb.py
@@ -0,0 +1,119 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0 + +"""Integration test: VDBUploadOperator writing to a real Milvus instance. + +Requires a running Milvus at the URI below. Run manually: + + python tests/integration_test_milvus_vdb.py + +This is NOT part of the pytest suite — it requires external infrastructure. +""" + +from __future__ import annotations + +import sys + +import pandas as pd + +MILVUS_URI = "http://172.20.0.4:19530" +COLLECTION_NAME = "nemo_retriever_integration_test" +EMBED_DIM = 128 + + +def _make_embedded_df(n: int = 10, dim: int = EMBED_DIM) -> pd.DataFrame: + """Build a minimal post-embed DataFrame.""" + rows = [] + for i in range(n): + embedding = [float(i * dim + j) / (n * dim) for j in range(dim)] + metadata = { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + rows.append( + { + "metadata": metadata, + "text_embeddings_1b_v2": {"embedding": embedding, "info_msg": None}, + "text": f"This is the content of page {i} in the test document.", + "path": f"/data/doc_{i}.pdf", + "page_number": i, + "document_type": "text", + } + ) + return pd.DataFrame(rows) + + +def main(): + from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator + from nemo_retriever.params.models import VdbUploadParams + + print(f"=== Milvus Integration Test ===") + print(f"Milvus URI: {MILVUS_URI}") + print(f"Collection: {COLLECTION_NAME}") + print() + + # --- Step 1: Create operator with Milvus backend --- + params = VdbUploadParams( + backend="milvus", + client_vdb_kwargs={ + "milvus_uri": MILVUS_URI, + "collection_name": COLLECTION_NAME, + "dense_dim": EMBED_DIM, + "recreate": True, + "gpu_index": False, + "stream": True, + "sparse": False, + }, + ) + op = VDBUploadOperator(params=params) + print(f"[OK] VDBUploadOperator created with backend='milvus'") + + # --- Step 2: Run two batches through the operator --- + df_batch1 = _make_embedded_df(5) + df_batch2 = _make_embedded_df(5) + + result1 = op.run(df_batch1) + print(f"[OK] Batch 1: {len(result1)} rows passed through, records written") + + result2 = op.run(df_batch2) + print(f"[OK] Batch 2: {len(result2)} rows passed through, records written") + + # --- Step 3: Verify data landed in Milvus --- + from pymilvus import MilvusClient + + client = MilvusClient(uri=MILVUS_URI) + + if not client.has_collection(COLLECTION_NAME): + print(f"[FAIL] Collection {COLLECTION_NAME} not found!") + return 1 + + stats = client.get_collection_stats(COLLECTION_NAME) + row_count = stats.get("row_count", 0) + print(f"[OK] Collection exists with {row_count} rows") + + # Query to verify data is searchable + query_vector = _make_embedded_df(1).iloc[0]["metadata"]["embedding"] + results = client.search( + collection_name=COLLECTION_NAME, + data=[query_vector], + limit=3, + output_fields=["text"], + ) + print(f"[OK] Search returned {len(results[0])} hits") + for i, hit in enumerate(results[0]): + text = hit.get("entity", {}).get("text", "")[:60] + print(f" Hit {i}: distance={hit['distance']:.4f} text='{text}...'") + + # --- Step 4: Cleanup --- + client.drop_collection(COLLECTION_NAME) + print(f"[OK] Cleaned up collection {COLLECTION_NAME}") + + print() + print("=== PASS: Milvus integration test completed successfully ===") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/nemo_retriever/tests/test_vdb_record_contract.py b/nemo_retriever/tests/test_vdb_record_contract.py index 031757cb5..67ac115f9 100644 --- a/nemo_retriever/tests/test_vdb_record_contract.py +++ 
b/nemo_retriever/tests/test_vdb_record_contract.py @@ -7,7 +7,6 @@ These tests validate: - build_vdb_records produces the correct canonical record format - build_vdb_records_from_dicts (transitional list[dict] path) matches - - handle_lancedb round-trips data correctly (regression) - _ensure_dict handles Arrow serialization robustness """ @@ -190,82 +189,6 @@ def test_empty_list(self): assert build_vdb_records_from_dicts([]) == [] -# --------------------------------------------------------------------------- -# Regression: handle_lancedb uses canonical builder -# --------------------------------------------------------------------------- - - -class TestHandleLancedbRegression: - def test_handle_lancedb_writes_valid_json_metadata(self, tmp_path): - """After refactoring, handle_lancedb should produce valid JSON metadata.""" - from nemo_retriever.vector_store.lancedb_store import handle_lancedb - import lancedb - - df = _make_sample_dataframe() - rows = df.to_dict("records") - - uri = str(tmp_path / "test_db") - handle_lancedb(rows, uri, "test_table", mode="overwrite") - - db = lancedb.connect(uri) - table = db.open_table("test_table") - result = table.to_pandas() - - assert len(result) == 1 - meta_str = result.iloc[0]["metadata"] - meta = json.loads(meta_str) - assert isinstance(meta, dict) - assert "page_number" in meta - - def test_handle_lancedb_accepts_dataframe(self, tmp_path): - """handle_lancedb should now accept a DataFrame directly.""" - from nemo_retriever.vector_store.lancedb_store import handle_lancedb - import lancedb - - df = _make_sample_dataframe() - - uri = str(tmp_path / "test_db") - handle_lancedb(df, uri, "test_table", mode="overwrite") - - db = lancedb.connect(uri) - table = db.open_table("test_table") - assert table.count_rows() == 1 - - def test_handle_lancedb_round_trip_preserves_text(self, tmp_path): - """Text content should survive the write→read round-trip.""" - from nemo_retriever.vector_store.lancedb_store import handle_lancedb - import lancedb - - df = _make_sample_dataframe() - rows = df.to_dict("records") - - uri = str(tmp_path / "test_db") - handle_lancedb(rows, uri, "test_table", mode="overwrite") - - db = lancedb.connect(uri) - table = db.open_table("test_table") - result = table.to_pandas() - - assert result.iloc[0]["text"] == "Hello world" - - def test_handle_lancedb_round_trip_preserves_path(self, tmp_path): - from nemo_retriever.vector_store.lancedb_store import handle_lancedb - import lancedb - - df = _make_sample_dataframe() - rows = df.to_dict("records") - - uri = str(tmp_path / "test_db") - handle_lancedb(rows, uri, "test_table", mode="overwrite") - - db = lancedb.connect(uri) - table = db.open_table("test_table") - result = table.to_pandas() - - assert result.iloc[0]["path"] == "/data/test.pdf" - assert result.iloc[0]["page_number"] == 0 - - # --------------------------------------------------------------------------- # _ensure_dict — Arrow serialization robustness # --------------------------------------------------------------------------- diff --git a/nemo_retriever/tests/test_vdb_upload_operator.py b/nemo_retriever/tests/test_vdb_upload_operator.py index 1adbd577f..682666204 100644 --- a/nemo_retriever/tests/test_vdb_upload_operator.py +++ b/nemo_retriever/tests/test_vdb_upload_operator.py @@ -211,16 +211,22 @@ class TestVDBUploadMilvus: def _make_milvus_params(self): return VdbUploadParams(backend="milvus", client_vdb_kwargs={"collection_name": "test"}) + def _make_mock_client(self): + mock_client = MagicMock() + mock_client.collection_name = 
"test" + mock_client.get_connection_params.return_value = ("test", {"milvus_uri": "http://localhost:19530"}) + mock_client.get_write_params.return_value = ("test", {"collection_name": "test"}) + return mock_client + @patch("nemo_retriever.graph.vdb_upload_operator.get_vdb_op_cls", create=True) def test_delegates_to_client_write_to_index(self, mock_get_cls): - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_get_cls.return_value = lambda **kwargs: mock_client params = self._make_milvus_params() op = VDBUploadOperator(params=params) df = _make_embedded_df(3) - # Patch get_vdb_op_cls at the right location with patch("nv_ingest_client.util.vdb.get_vdb_op_cls", return_value=lambda **kw: mock_client): op.run(df) @@ -234,7 +240,7 @@ def test_delegates_to_client_write_to_index(self, mock_get_cls): @patch("nemo_retriever.graph.vdb_upload_operator.get_vdb_op_cls", create=True) def test_multiple_batches_call_write_per_batch(self, mock_get_cls): - mock_client = MagicMock() + mock_client = self._make_mock_client() mock_get_cls.return_value = lambda **kwargs: mock_client params = self._make_milvus_params() From f3ba6162125d2c8f95bf5f12b834fdaf5d81937e Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Tue, 14 Apr 2026 20:11:10 +0000 Subject: [PATCH 3/6] vector_store: fix test_batch_pipeline for removed handle_lancedb - Remove monkeypatches for _ensure_lancedb_table and handle_lancedb - Add vdb_upload/store/caption/dedup stubs to _FakeIngestor - Add count() to _FakeDataset for no-collect code path - Delete dead handle_lancedb function and its tests Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Jacob Ioffe --- nemo_retriever/tests/test_batch_pipeline.py | 23 ++++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/nemo_retriever/tests/test_batch_pipeline.py b/nemo_retriever/tests/test_batch_pipeline.py index b5d19b97d..f66e5f0e7 100644 --- a/nemo_retriever/tests/test_batch_pipeline.py +++ b/nemo_retriever/tests/test_batch_pipeline.py @@ -20,6 +20,9 @@ def materialize(self): def take_all(self): return [] + def count(self): + return 1 + def groupby(self, _key): class _FakeGrouped: @staticmethod @@ -80,6 +83,18 @@ def embed(self, params): self.embed_params = params return self + def vdb_upload(self, params=None): + return self + + def store(self, params=None): + return self + + def caption(self, params=None): + return self + + def dedup(self, params=None): + return self + def ingest(self, params=None): return _FakeDataset() @@ -108,8 +123,6 @@ def test_batch_pipeline_accepts_multimodal_embed_and_page_image_flags(tmp_path, fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) class _FakeTable: @@ -154,8 +167,6 @@ def test_batch_pipeline_routes_audio_input_to_audio_ingestor(tmp_path, monkeypat fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) monkeypatch.setattr( batch_pipeline, 
"asr_params_from_env", lambda: SimpleNamespace(model_copy=lambda update: update) @@ -206,8 +217,6 @@ def test_batch_pipeline_routes_beir_mode_to_evaluator(tmp_path, monkeypatch) -> fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setattr(detection_summary_module, "print_run_summary", lambda *args, **kwargs: None) class _FakeTable: @@ -267,8 +276,6 @@ def test_batch_pipeline_accepts_harness_runtime_metric_flags(tmp_path, monkeypat fake_ingestor = _FakeIngestor() monkeypatch.setattr(batch_pipeline, "GraphIngestor", lambda *args, **kwargs: fake_ingestor) - monkeypatch.setattr(batch_pipeline, "_ensure_lancedb_table", lambda *args, **kwargs: None) - monkeypatch.setattr(batch_pipeline, "handle_lancedb", lambda *args, **kwargs: None) monkeypatch.setitem(sys.modules, "ray", SimpleNamespace(shutdown=lambda: None)) class _FakeTable: From 2328a0dd2edbba5c490629b9922e2b27ffcb0505 Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 16 Apr 2026 15:40:12 +0000 Subject: [PATCH 4/6] vector_store: fix pre-commit formatting and unused imports - black reformatting on test files - remove unused imports flagged by flake8 - end-of-file-fixer on lancedb_store.py Co-Authored-By: Claude Opus 4 Signed-off-by: Jacob Ioffe --- .../nemo_retriever/vector_store/lancedb_store.py | 3 +-- .../tests/integration_test_milvus_recall.py | 4 ++-- nemo_retriever/tests/test_vdb_record_contract.py | 14 ++++++++------ nemo_retriever/tests/test_vdb_upload_operator.py | 12 +++++++----- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py index e25ecc349..b822740bc 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_store.py @@ -12,7 +12,7 @@ from nv_ingest_client.util.vdb.lancedb import LanceDB from nemo_retriever.params.models import LanceDbParams -from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts +from nemo_retriever.vector_store.vdb_records import build_vdb_records import pandas as pd logger = logging.getLogger(__name__) @@ -159,4 +159,3 @@ def write_text_embeddings_dir_to_lancedb( "failed": failed, "lancedb": {"uri": cfg.lancedb_uri, "table_name": cfg.table_name, "overwrite": cfg.overwrite}, } - diff --git a/nemo_retriever/tests/integration_test_milvus_recall.py b/nemo_retriever/tests/integration_test_milvus_recall.py index 16064f195..bbeddfe4e 100644 --- a/nemo_retriever/tests/integration_test_milvus_recall.py +++ b/nemo_retriever/tests/integration_test_milvus_recall.py @@ -32,8 +32,7 @@ def ingest_jp20_to_milvus(): """Run the full graph pipeline with Milvus as the VDB backend.""" from nemo_retriever import create_ingestor - from nemo_retriever.params import EmbedParams, ExtractParams, VdbUploadParams - from nemo_retriever.params.models import LanceDbParams + from nemo_retriever.params import EmbedParams, VdbUploadParams print("--- Step 1: Ingesting jp20 into Milvus ---") t0 = time.perf_counter() @@ -244,6 +243,7 @@ def main(): except Exception as e: print(f"\n[FAIL] {e}") import traceback + traceback.print_exc() return 1 diff --git a/nemo_retriever/tests/test_vdb_record_contract.py 
b/nemo_retriever/tests/test_vdb_record_contract.py index 67ac115f9..0a346e837 100644 --- a/nemo_retriever/tests/test_vdb_record_contract.py +++ b/nemo_retriever/tests/test_vdb_record_contract.py @@ -16,7 +16,6 @@ import json import pandas as pd -import pytest # --------------------------------------------------------------------------- @@ -193,6 +192,7 @@ def test_empty_list(self): # _ensure_dict — Arrow serialization robustness # --------------------------------------------------------------------------- + class TestEnsureDict: """Tests for the _ensure_dict helper that handles Arrow-serialized dict columns.""" @@ -235,11 +235,13 @@ def test_string_metadata_with_embedding(self): from nemo_retriever.vector_store.vdb_records import build_vdb_records embedding = [0.1, 0.2, 0.3, 0.4] - metadata = json.dumps({ - "embedding": embedding, - "source_path": "/data/test.pdf", - "content_metadata": {"hierarchy": {"page": 0}}, - }) + metadata = json.dumps( + { + "embedding": embedding, + "source_path": "/data/test.pdf", + "content_metadata": {"hierarchy": {"page": 0}}, + } + ) df = pd.DataFrame( [ { diff --git a/nemo_retriever/tests/test_vdb_upload_operator.py b/nemo_retriever/tests/test_vdb_upload_operator.py index 682666204..322a7e80b 100644 --- a/nemo_retriever/tests/test_vdb_upload_operator.py +++ b/nemo_retriever/tests/test_vdb_upload_operator.py @@ -144,11 +144,13 @@ def _make_string_encoded_df(self, n: int = 3, dim: int = 4) -> pd.DataFrame: rows = [] for i in range(n): embedding = [float(i + j) for j in range(dim)] - metadata = json.dumps({ - "embedding": embedding, - "source_path": f"/data/doc_{i}.pdf", - "content_metadata": {"hierarchy": {"page": i}}, - }) + metadata = json.dumps( + { + "embedding": embedding, + "source_path": f"/data/doc_{i}.pdf", + "content_metadata": {"hierarchy": {"page": i}}, + } + ) embed_payload = json.dumps({"embedding": embedding, "info_msg": None}) rows.append( { From f3d7d19a6f1266e18910d4927105ef65224be72a Mon Sep 17 00:00:00 2001 From: Jacob Ioffe Date: Thu, 16 Apr 2026 22:16:48 +0000 Subject: [PATCH 5/6] vector_store: accept pre-constructed client VDB in graph upload Lets callers pass a ready-made LanceDB/Milvus instance from nv_ingest_client.util.vdb directly into GraphIngestor.vdb_upload(vdb_op=...) so the graph wraps it instead of rebuilding one from VdbUploadParams. The operator captures vdb_op in get_constructor_kwargs() so it round-trips to Ray workers unchanged. 
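A minimal sketch of the intended call shape, assuming a reachable Milvus
(URI, collection name, and dimension below are illustrative, not defaults):

    from nv_ingest_client.util.vdb import get_vdb_op_cls
    from nemo_retriever import create_ingestor

    Milvus = get_vdb_op_cls("milvus")
    vdb = Milvus(milvus_uri="http://localhost:19530", collection_name="docs", dense_dim=2048)
    (
        create_ingestor(run_mode="batch")
        .files("/data/*.pdf")
        .extract(extract_text=True, extract_tables=True, extract_charts=True)
        .embed()
        .vdb_upload(vdb_op=vdb)  # the graph wraps this instance as-is
        .ingest()
    )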
- VDBUploadOperator: new vdb_op kwarg, backend derived from its class name - GraphIngestor.vdb_upload + build_graph + _append_ordered_transform_stages thread vdb_op through to the operator - _finalize_vdb reuses the passed LanceDB instance for post-run indexing - Extracted build_client_lancedb helper; removed dead build_lancedb_rows / build_vdb_records_from_dicts; fixed latent missing logger in graph_ingestor - Added 3 unit tests + end-to-end integration test covering both backends (bo20 ingested 831 rows to both LanceDB and Milvus via passthrough) Signed-off-by: Jacob Ioffe --- .../nemo_retriever/graph/ingestor_runtime.py | 5 +- .../graph/vdb_upload_operator.py | 43 ++--- .../src/nemo_retriever/graph_ingestor.py | 63 +++++--- .../nemo_retriever/vector_store/__init__.py | 3 +- .../vector_store/lancedb_utils.py | 45 +++--- .../vector_store/vdb_records.py | 21 --- .../integration_test_vdb_op_passthrough.py | 150 ++++++++++++++++++ .../tests/test_vdb_record_contract.py | 29 ---- .../tests/test_vdb_upload_operator.py | 59 +++++++ 9 files changed, 305 insertions(+), 113 deletions(-) create mode 100644 nemo_retriever/tests/integration_test_vdb_op_passthrough.py diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py index 96e703f55..728c47ad1 100644 --- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py +++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py @@ -321,6 +321,7 @@ def _append_ordered_transform_stages( store_params: Any | None, embed_params: Any | None, vdb_upload_params: Any | None, + vdb_op: Any | None, stage_order: tuple[str, ...], supports_dedup: bool, reshape_for_modal_content: bool, @@ -384,7 +385,7 @@ def _append_ordered_transform_stages( ) graph = graph >> _BatchEmbedActor(params=embed_params) elif stage_name == "vdb_upload" and vdb_upload_params is not None: - graph = graph >> VDBUploadOperator(params=vdb_upload_params) + graph = graph >> VDBUploadOperator(params=vdb_upload_params, vdb_op=vdb_op) return graph @@ -404,6 +405,7 @@ def build_graph( caption_params: Any | None = None, store_params: Any | None = None, vdb_upload_params: Any | None = None, + vdb_op: Any | None = None, stage_order: tuple[str, ...] = (), ) -> Graph: """Build a batch graph from explicit params or a shared execution plan.""" @@ -567,6 +569,7 @@ def build_graph( store_params=store_params, embed_params=embed_params, vdb_upload_params=vdb_upload_params, + vdb_op=vdb_op, stage_order=stage_order, supports_dedup=True, reshape_for_modal_content=True, diff --git a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py index c10572c00..60988132f 100644 --- a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py @@ -66,6 +66,15 @@ class VDBUploadOperator(AbstractOperator, CPUOperator): The DataFrame is passed through unchanged — this is a side-effect operator. + **Configuring the backend.** Callers can either: + + * pass ``params`` and let the operator construct the client VDB from + ``VdbUploadParams.backend`` + ``client_vdb_kwargs``, or + * pass a pre-constructed client VDB via ``vdb_op`` + (e.g. ``Milvus(...)`` / ``LanceDB(...)`` from + ``nv_ingest_client.util.vdb``). ``params`` still supplies the + record-shaping config (embedding column, text column, etc.). + **Concurrency**: This operator must run with ``concurrency=1`` in batch mode. 
The single actor creates the table on its first write (respecting ``overwrite``) and appends on subsequent writes. Index creation happens
@@ -77,13 +86,15 @@ def __init__(
        self,
        *,
        params: Any = None,
+        vdb_op: Any = None,
    ) -> None:
        super().__init__()
        from nemo_retriever.params.models import LanceDbParams, VdbUploadParams

-        # Store as self.params so get_constructor_kwargs() can capture it
-        # for deferred reconstruction on Ray workers.
+        # Store on self so get_constructor_kwargs() captures both for
+        # deferred reconstruction on Ray workers.
        self.params = params
+        self.vdb_op = vdb_op

        if isinstance(params, VdbUploadParams):
            self._vdb_params = params
@@ -95,9 +106,14 @@ def __init__(
            self._vdb_params = VdbUploadParams()
            self._lance_params = self._vdb_params.lancedb

-        self._backend_name: str = getattr(self._vdb_params, "backend", "lancedb") if self._vdb_params else "lancedb"
-        self._client_vdb: Any = None
+        if vdb_op is not None:
+            self._backend_name = type(vdb_op).__name__.lower()
+        else:
+            self._backend_name = getattr(self._vdb_params, "backend", "lancedb") if self._vdb_params else "lancedb"
+
+        self._client_vdb: Any = vdb_op
        self._table: Any = None
+        self._index_created: bool = False
        self._pending_records: List[Dict[str, Any]] = []

    # ------------------------------------------------------------------
@@ -109,18 +125,9 @@ def _create_client_vdb(self) -> Any:
        from nv_ingest_client.util.vdb import get_vdb_op_cls

        if self._backend_name == "lancedb":
-            LanceDB = get_vdb_op_cls("lancedb")
-            return LanceDB(
-                uri=self._lance_params.lancedb_uri,
-                table_name=self._lance_params.table_name,
-                overwrite=self._lance_params.overwrite,
-                index_type=self._lance_params.index_type,
-                metric=self._lance_params.metric,
-                num_partitions=self._lance_params.num_partitions,
-                num_sub_vectors=self._lance_params.num_sub_vectors,
-                hybrid=self._lance_params.hybrid,
-                fts_language=self._lance_params.fts_language,
-            )
+            from nemo_retriever.vector_store.lancedb_utils import build_client_lancedb
+
+            return build_client_lancedb(self._lance_params)

        kwargs = getattr(self._vdb_params, "client_vdb_kwargs", {}) or {}
        return get_vdb_op_cls(self._backend_name)(**kwargs)
@@ -214,10 +221,10 @@ def _write_via_client(self, records: List[Dict[str, Any]]) -> None:
        if not nvingest_records:
            return

-        if self._table is None:
+        if not self._index_created:
            collection_name, create_params = self._client_vdb.get_connection_params()
            self._client_vdb.create_index(collection_name=collection_name, **create_params)
-            self._table = True  # sentinel: schema created
+            self._index_created = True

        _, write_params = self._client_vdb.get_write_params()
        self._client_vdb.write_to_index(nvingest_records, **write_params)
diff --git a/nemo_retriever/src/nemo_retriever/graph_ingestor.py b/nemo_retriever/src/nemo_retriever/graph_ingestor.py
index 237c05597..c09986b56 100644
--- a/nemo_retriever/src/nemo_retriever/graph_ingestor.py
+++ b/nemo_retriever/src/nemo_retriever/graph_ingestor.py
@@ -27,6 +27,7 @@
from __future__ import annotations

import json
+import logging
from typing import Any, Callable, Dict, List, Optional, Union

from nemo_retriever.graph import InprocessExecutor, RayDataExecutor
@@ -47,6 +48,8 @@
)
from nemo_retriever.utils.remote_auth import resolve_remote_api_key

+logger = logging.getLogger(__name__)
+

def _resolve_api_key(params: Any) -> Any:
    """Auto-resolve api_key from NVIDIA_API_KEY / NGC_API_KEY if not explicitly set."""
@@ -148,6 +151,7 @@ def __init__(
        self._dedup_params: Any = None
        self._store_params: Any = None
self._vdb_upload_params: Any = None + self._vdb_op: Any = None # Ordered list of stage names; "extract" is tracked but excluded from # the post-extraction stage_order passed to graph builders. self._stage_order: List[str] = [] @@ -241,14 +245,32 @@ def embed(self, params: Optional[EmbedParams] = None, **kwargs: Any) -> "GraphIn self._record_stage("embed") return self - def vdb_upload(self, params: Optional[VdbUploadParams] = None, **kwargs: Any) -> "GraphIngestor": + def vdb_upload( + self, + params: Optional[VdbUploadParams] = None, + *, + vdb_op: Any = None, + **kwargs: Any, + ) -> "GraphIngestor": """Record a VDB upload sink. Writes embeddings to a vector store as data flows through the graph, eliminating the need to collect all results on the driver. Index creation happens automatically after the pipeline completes. + + Parameters + ---------- + params: + Upload/record-shaping config. When ``vdb_op`` is ``None`` this + also determines which backend to construct. + vdb_op: + Optional pre-constructed client VDB instance (e.g. ``Milvus(...)`` + or ``LanceDB(...)`` from ``nv_ingest_client.util.vdb``). When + provided, the operator uses it directly instead of building one + from ``params``. """ self._vdb_upload_params = _coerce(params, kwargs, default_factory=VdbUploadParams) + self._vdb_op = vdb_op self._record_stage("vdb_upload") return self @@ -301,6 +323,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: dedup_params=self._dedup_params, store_params=self._store_params, vdb_upload_params=self._vdb_upload_params, + vdb_op=self._vdb_op, stage_order=post_extract_order, ) # Derive per-node Ray scheduling config from BatchTuningParams plus @@ -351,6 +374,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: dedup_params=self._dedup_params, store_params=self._store_params, vdb_upload_params=self._vdb_upload_params, + vdb_op=self._vdb_op, stage_order=post_extract_order, ) executor = InprocessExecutor(graph, show_progress=self._show_progress) @@ -445,43 +469,44 @@ def _finalize_vdb(self) -> None: from nemo_retriever.params.models import LanceDbParams, VdbUploadParams - if isinstance(self._vdb_upload_params, VdbUploadParams): + if self._vdb_op is not None: + backend_type = type(self._vdb_op).__name__.lower() + elif isinstance(self._vdb_upload_params, VdbUploadParams): backend_type = self._vdb_upload_params.backend - lance_params = self._vdb_upload_params.lancedb elif isinstance(self._vdb_upload_params, LanceDbParams): backend_type = "lancedb" - lance_params = self._vdb_upload_params else: return if backend_type != "lancedb": return + if isinstance(self._vdb_upload_params, VdbUploadParams): + lance_params = self._vdb_upload_params.lancedb + elif isinstance(self._vdb_upload_params, LanceDbParams): + lance_params = self._vdb_upload_params + else: + lance_params = LanceDbParams() + if not lance_params.create_index: return import lancedb - from nv_ingest_client.util.vdb import get_vdb_op_cls + from nemo_retriever.vector_store.lancedb_utils import build_client_lancedb + + uri = getattr(self._vdb_op, "uri", lance_params.lancedb_uri) if self._vdb_op else lance_params.lancedb_uri + table_name = ( + getattr(self._vdb_op, "table_name", lance_params.table_name) if self._vdb_op else lance_params.table_name + ) try: - db = lancedb.connect(uri=lance_params.lancedb_uri) - table = db.open_table(lance_params.table_name) + db = lancedb.connect(uri=uri) + table = db.open_table(table_name) except Exception: return - ClientLanceDB = get_vdb_op_cls("lancedb") - client = ClientLanceDB( - 
uri=lance_params.lancedb_uri, - table_name=lance_params.table_name, - overwrite=False, - index_type=lance_params.index_type, - metric=lance_params.metric, - num_partitions=lance_params.num_partitions, - num_sub_vectors=lance_params.num_sub_vectors, - hybrid=lance_params.hybrid, - fts_language=lance_params.fts_language, - ) + client = self._vdb_op if self._vdb_op is not None else build_client_lancedb(lance_params, overwrite=False) try: client.write_to_index(None, table=table) except RuntimeError: diff --git a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py index de519b5e8..3320092d4 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/__init__.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/__init__.py @@ -8,12 +8,11 @@ write_embeddings_to_lancedb, write_text_embeddings_dir_to_lancedb, ) -from .vdb_records import build_vdb_records, build_vdb_records_from_dicts +from .vdb_records import build_vdb_records __all__ = [ "app", "build_vdb_records", - "build_vdb_records_from_dicts", "create_lancedb_index", "write_embeddings_to_lancedb", "write_text_embeddings_dir_to_lancedb", diff --git a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py index 72db70ab4..167b80fbb 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/lancedb_utils.py @@ -161,29 +161,6 @@ def build_lancedb_row( return row_out -def build_lancedb_rows( - df: Any, - *, - embedding_column: str = "text_embeddings_1b_v2", - embedding_key: str = "embedding", - text_column: str = "text", - include_text: bool = True, -) -> List[Dict[str, Any]]: - """Build LanceDB rows from a pandas DataFrame.""" - rows: List[Dict[str, Any]] = [] - for row in df.itertuples(index=False): - row_out = build_lancedb_row( - row, - embedding_column=embedding_column, - embedding_key=embedding_key, - text_column=text_column, - include_text=include_text, - ) - if row_out is not None: - rows.append(row_out) - return rows - - def lancedb_schema(vector_dim: int = 2048) -> Any: """Return a PyArrow schema for the standard LanceDB table layout.""" import pyarrow as pa # type: ignore @@ -213,6 +190,28 @@ def infer_vector_dim(rows: List[Dict[str, Any]]) -> int: return 0 +def build_client_lancedb(lance_params: Any, *, overwrite: Optional[bool] = None) -> Any: + """Construct an ``nv_ingest_client.util.vdb.LanceDB`` from ``LanceDbParams``. + + Pass ``overwrite=False`` to open an existing table (e.g. for post-pipeline + index creation); otherwise ``lance_params.overwrite`` is used. 
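+
+    A short sketch of the two call modes (``lance_params`` is any
+    ``LanceDbParams``; the second form is what ``_finalize_vdb`` uses):
+
+        build_client_lancedb(lance_params)                   # write path: honors lance_params.overwrite
+        build_client_lancedb(lance_params, overwrite=False)  # index path: open the existing table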
+ """ + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + return LanceDB( + uri=lance_params.lancedb_uri, + table_name=lance_params.table_name, + overwrite=lance_params.overwrite if overwrite is None else overwrite, + index_type=lance_params.index_type, + metric=lance_params.metric, + num_partitions=lance_params.num_partitions, + num_sub_vectors=lance_params.num_sub_vectors, + hybrid=lance_params.hybrid, + fts_language=lance_params.fts_language, + ) + + def create_or_append_lancedb_table( db: Any, table_name: str, diff --git a/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py index 668e2a89a..c25d2256e 100644 --- a/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py +++ b/nemo_retriever/src/nemo_retriever/vector_store/vdb_records.py @@ -56,24 +56,3 @@ def build_vdb_records( if row_out is not None: rows.append(row_out) return rows - - -def build_vdb_records_from_dicts( - records: List[Dict[str, Any]], - *, - embedding_column: str = "text_embeddings_1b_v2", - embedding_key: str = "embedding", - text_column: str = "text", - include_text: bool = True, -) -> List[Dict[str, Any]]: - """Convert a list of dicts (e.g. from ``take_all()``) into canonical VDB records.""" - if not records: - return [] - df = pd.DataFrame(records) - return build_vdb_records( - df, - embedding_column=embedding_column, - embedding_key=embedding_key, - text_column=text_column, - include_text=include_text, - ) diff --git a/nemo_retriever/tests/integration_test_vdb_op_passthrough.py b/nemo_retriever/tests/integration_test_vdb_op_passthrough.py new file mode 100644 index 000000000..fbdfc8d0a --- /dev/null +++ b/nemo_retriever/tests/integration_test_vdb_op_passthrough.py @@ -0,0 +1,150 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Integration test: pass pre-constructed client VDB instances into the graph. + +Exercises the lead-review ask for PR 1847 — that `GraphIngestor.vdb_upload` +accept a pre-built `nv_ingest_client.util.vdb` instance (`LanceDB(...)` or +`Milvus(...)`) and wire it directly into `VDBUploadOperator` without the +operator reconstructing it from `VdbUploadParams`. + +Runs a small bo20 ingest against each backend and verifies the target table / +collection was populated. 
+ +Requires: + - Running Milvus at ``MILVUS_URI`` + - GPU for extraction + embedding + - bo20 dataset at ``DATASET_DIR`` + +Usage: + python tests/integration_test_vdb_op_passthrough.py [--milvus-uri URI] +""" + +from __future__ import annotations + +import shutil +import sys +import tempfile +import time +import traceback +from pathlib import Path + +MILVUS_URI = "http://172.20.0.4:19530" +MILVUS_COLLECTION = "vdb_op_passthrough_bo20" +DATASET_DIR = "/datasets/nv-ingest/bo20" +EMBED_MODEL = "nvidia/llama-nemotron-embed-1b-v2" +EMBED_DIM = 2048 + + +def _run_ingest(vdb_op): + from nemo_retriever import create_ingestor + from nemo_retriever.params import EmbedParams + + t0 = time.perf_counter() + ( + create_ingestor(run_mode="batch") + .files(DATASET_DIR + "/*.pdf") + .extract(extract_text=True, extract_tables=True, extract_charts=True) + .embed(EmbedParams(model_name=EMBED_MODEL, inference_batch_size=32)) + .vdb_upload(vdb_op=vdb_op) + .ingest() + ) + return time.perf_counter() - t0 + + +def test_lancedb_passthrough(): + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + tmp_dir = Path(tempfile.mkdtemp(prefix="vdb_op_passthrough_lance_")) + uri = str(tmp_dir / "lancedb") + table_name = "bo20_lance_passthrough" + + print("--- LanceDB pre-constructed instance ---") + print(f" uri: {uri}") + print(f" table: {table_name}") + + client = LanceDB(uri=uri, table_name=table_name, overwrite=True) + assert type(client).__name__ == "LanceDB", f"expected LanceDB, got {type(client).__name__}" + + elapsed = _run_ingest(client) + + import lancedb + + db = lancedb.connect(uri) + table = db.open_table(table_name) + n = table.count_rows() + print(f"[OK] LanceDB passthrough ingested {n} rows in {elapsed:.1f}s") + + shutil.rmtree(tmp_dir, ignore_errors=True) + return n > 0 + + +def test_milvus_passthrough(): + from nv_ingest_client.util.vdb import get_vdb_op_cls + from pymilvus import MilvusClient + + Milvus = get_vdb_op_cls("milvus") + print("--- Milvus pre-constructed instance ---") + print(f" uri: {MILVUS_URI}") + print(f" collection: {MILVUS_COLLECTION}") + + client = Milvus( + milvus_uri=MILVUS_URI, + collection_name=MILVUS_COLLECTION, + dense_dim=EMBED_DIM, + recreate=True, + gpu_index=False, + stream=True, + sparse=False, + ) + assert type(client).__name__ == "Milvus", f"expected Milvus, got {type(client).__name__}" + + elapsed = _run_ingest(client) + + mc = MilvusClient(uri=MILVUS_URI) + mc.load_collection(MILVUS_COLLECTION) + stats = mc.get_collection_stats(collection_name=MILVUS_COLLECTION) + n = int(stats.get("row_count", 0)) + print(f"[OK] Milvus passthrough ingested {n} rows in {elapsed:.1f}s") + + mc.drop_collection(MILVUS_COLLECTION) + return n > 0 + + +def main(): + print("=" * 60) + print("VDB-op passthrough integration test (bo20)") + print(f" Dataset: {DATASET_DIR}") + print("=" * 60) + + results = {} + try: + results["lancedb"] = test_lancedb_passthrough() + except Exception as exc: # noqa: BLE001 + print(f"[FAIL] LanceDB passthrough: {exc}") + traceback.print_exc() + results["lancedb"] = False + + try: + results["milvus"] = test_milvus_passthrough() + except Exception as exc: # noqa: BLE001 + print(f"[FAIL] Milvus passthrough: {exc}") + traceback.print_exc() + results["milvus"] = False + + print() + print("=" * 60) + print("RESULTS") + for backend, ok in results.items(): + print(f" {backend}: {'PASS' if ok else 'FAIL'}") + print("=" * 60) + return 0 if all(results.values()) else 1 + + +if __name__ == "__main__": + if "--milvus-uri" in sys.argv: + 
idx = sys.argv.index("--milvus-uri") + MILVUS_URI = sys.argv[idx + 1] + sys.exit(main()) diff --git a/nemo_retriever/tests/test_vdb_record_contract.py b/nemo_retriever/tests/test_vdb_record_contract.py index 0a346e837..29fecdf85 100644 --- a/nemo_retriever/tests/test_vdb_record_contract.py +++ b/nemo_retriever/tests/test_vdb_record_contract.py @@ -6,7 +6,6 @@ These tests validate: - build_vdb_records produces the correct canonical record format - - build_vdb_records_from_dicts (transitional list[dict] path) matches - _ensure_dict handles Arrow serialization robustness """ @@ -160,34 +159,6 @@ def test_include_text_false(self): assert rows[0]["text"] == "" -# --------------------------------------------------------------------------- -# Transitional list[dict] builder tests -# --------------------------------------------------------------------------- - - -class TestBuildVdbRecordsFromDicts: - def test_matches_dataframe_path(self): - """list[dict] path should produce identical output to DataFrame path.""" - from nemo_retriever.vector_store.vdb_records import build_vdb_records, build_vdb_records_from_dicts - - df = _make_sample_dataframe() - records = df.to_dict("records") - - from_df = build_vdb_records(df) - from_dicts = build_vdb_records_from_dicts(records) - - assert len(from_df) == len(from_dicts) - assert from_df[0]["vector"] == from_dicts[0]["vector"] - assert from_df[0]["text"] == from_dicts[0]["text"] - assert from_df[0]["path"] == from_dicts[0]["path"] - assert from_df[0]["page_number"] == from_dicts[0]["page_number"] - - def test_empty_list(self): - from nemo_retriever.vector_store.vdb_records import build_vdb_records_from_dicts - - assert build_vdb_records_from_dicts([]) == [] - - # --------------------------------------------------------------------------- # _ensure_dict — Arrow serialization robustness # --------------------------------------------------------------------------- diff --git a/nemo_retriever/tests/test_vdb_upload_operator.py b/nemo_retriever/tests/test_vdb_upload_operator.py index 322a7e80b..68820b7f7 100644 --- a/nemo_retriever/tests/test_vdb_upload_operator.py +++ b/nemo_retriever/tests/test_vdb_upload_operator.py @@ -256,6 +256,65 @@ def test_multiple_batches_call_write_per_batch(self, mock_get_cls): assert mock_client.write_to_index.call_count == 2 +# --------------------------------------------------------------------------- +# Pre-constructed VDB injection +# --------------------------------------------------------------------------- + + +class TestPreConstructedVDB: + """Operator accepts a pre-built client VDB instance (lead's review ask).""" + + def test_accepts_client_lancedb_instance(self, lance_params): + """Passing a client LanceDB object skips internal construction.""" + from nv_ingest_client.util.vdb import get_vdb_op_cls + + LanceDB = get_vdb_op_cls("lancedb") + client_vdb = LanceDB( + uri=lance_params.lancedb_uri, + table_name=lance_params.table_name, + overwrite=True, + ) + + op = VDBUploadOperator(vdb_op=client_vdb) + assert op._client_vdb is client_vdb + assert op._backend_name == "lancedb" + + op.run(_make_embedded_df(2)) + + import lancedb + + db = lancedb.connect(lance_params.lancedb_uri) + table = db.open_table(lance_params.table_name) + assert table.count_rows() == 2 + + def test_accepts_arbitrary_vdb_instance(self): + """A non-LanceDB instance routes through the client write path.""" + mock_vdb = MagicMock() + type(mock_vdb).__name__ = "Milvus" + mock_vdb.get_connection_params.return_value = ("c", {"milvus_uri": "x"}) + 
mock_vdb.get_write_params.return_value = ("c", {"collection_name": "c"})
+
+        op = VDBUploadOperator(vdb_op=mock_vdb)
+        assert op._backend_name == "milvus"
+        assert op._client_vdb is mock_vdb
+
+        op.run(_make_embedded_df(2))
+
+        mock_vdb.create_index.assert_called_once()
+        mock_vdb.write_to_index.assert_called_once()
+
+    def test_constructor_kwargs_round_trip(self):
+        """get_constructor_kwargs captures both params and vdb_op for Ray reconstruction."""
+        mock_vdb = MagicMock()
+        type(mock_vdb).__name__ = "Milvus"
+
+        op = VDBUploadOperator(params=VdbUploadParams(), vdb_op=mock_vdb)
+        kwargs = op.get_constructor_kwargs()
+        assert "vdb_op" in kwargs
+        assert kwargs["vdb_op"] is mock_vdb
+        assert "params" in kwargs
+
+
 # ---------------------------------------------------------------------------
 # Finalization
 # ---------------------------------------------------------------------------

From 2d590e3e19d5b39aae2b71c79295b20c4c53d14d Mon Sep 17 00:00:00 2001
From: Jacob Ioffe
Date: Tue, 21 Apr 2026 20:32:38 +0000
Subject: [PATCH 6/6] vector_store: add driver-side finalize hook for VDB sinks

Preparatory checkpoint for the Milvus streaming-write fix.

Adds a no-op `VDBUploadOperator.finalize()` that the driver calls once
after `executor.ingest()` returns (via a new `vdb_upload_ops_out` handle
threaded through `build_graph`), so one-shot flush/wait-for-index work
can live outside the per-batch lifecycle that `AbstractOperator.run()`
fires. LanceDB and the bulk-fallback write path remain no-ops.

Also reverts the Milvus CLI additions from `graph_pipeline.py` (backend
selection moves back to the caller passing `vdb_op=`) and sanitizes the
bo20 dataset path in `test_configs.yaml`.

Signed-off-by: Jacob Ioffe
---
 nemo_retriever/harness/test_configs.yaml      |  2 +-
 .../nemo_retriever/examples/graph_pipeline.py | 73 +++++++++----------
 .../nemo_retriever/graph/ingestor_runtime.py  |  8 +-
 .../graph/vdb_upload_operator.py              | 10 +++
 .../src/nemo_retriever/graph_ingestor.py      |  8 ++
 5 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/nemo_retriever/harness/test_configs.yaml b/nemo_retriever/harness/test_configs.yaml
index 4d90e05fe..f98b1af3e 100644
--- a/nemo_retriever/harness/test_configs.yaml
+++ b/nemo_retriever/harness/test_configs.yaml
@@ -200,7 +200,7 @@ presets:

 datasets:
   bo20:
-    path: /home/jdyer/datasets/bo20
+    path: /datasets/nv-ingest/bo20
     query_csv: null
     input_type: pdf
     recall_required: false
diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
index 081d8a322..dc3760272 100644
--- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
+++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
@@ -130,15 +130,19 @@ def _write_runtime_summary(
     runtime_metrics_dir: Optional[Path],
     runtime_metrics_prefix: Optional[str],
     payload: dict[str, object],
+    metrics_output_file: Optional[Path] = None,
 ) -> None:
-    if runtime_metrics_dir is None and not runtime_metrics_prefix:
-        return
+    if runtime_metrics_dir is not None or runtime_metrics_prefix:
+        target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
+        target_dir.mkdir(parents=True, exist_ok=True)
+        prefix = (runtime_metrics_prefix or "run").strip() or "run"
+        target = target_dir / f"{prefix}.runtime.summary.json"
+        target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")

-    target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
-    target_dir.mkdir(parents=True, exist_ok=True)
-    prefix = (runtime_metrics_prefix or "run").strip() or "run"
-    target = target_dir / f"{prefix}.runtime.summary.json"
-    target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
+    if metrics_output_file is not None:
+        out_path = Path(metrics_output_file).expanduser()
+        out_path.parent.mkdir(parents=True, exist_ok=True)
+        out_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


 def _count_input_units(result_df) -> int:
@@ -292,6 +296,13 @@ def main(
     runtime_metrics_dir: Optional[Path] = typer.Option(None, "--runtime-metrics-dir", path_type=Path),
     runtime_metrics_prefix: Optional[str] = typer.Option(None, "--runtime-metrics-prefix"),
     detection_summary_file: Optional[Path] = typer.Option(None, "--detection-summary-file", path_type=Path),
+    metrics_output_file: Optional[Path] = typer.Option(
+        None,
+        "--metrics-output-file",
+        path_type=Path,
+        dir_okay=False,
+        help="JSON file path to write structured run metrics (used by the harness).",
+    ),
     log_file: Optional[Path] = typer.Option(None, "--log-file", path_type=Path, dir_okay=False),
 ) -> None:
     _ = ctx
@@ -510,9 +521,9 @@ def main(

     ingestor = ingestor.embed(embed_params)

-    # VDB upload runs inside the graph — rows stream to LanceDB as they
-    # are produced, so we never need to collect the entire result set on
-    # the driver just for the LanceDB write. Index creation happens
+    # VDB upload runs inside the graph — rows stream to the configured
+    # backend as they are produced, so we never need to collect the entire
+    # result set on the driver just for the write. Index creation happens
     # automatically in GraphIngestor._finalize_vdb() after the pipeline.
     ingestor = ingestor.vdb_upload(
         VdbUploadParams(
@@ -586,13 +597,7 @@ def main(
     # ------------------------------------------------------------------
     # Recall / BEIR evaluation
    # ------------------------------------------------------------------
-    import lancedb as _lancedb_mod
-
-    db = _lancedb_mod.connect(lancedb_uri)
-    table = db.open_table(LANCEDB_TABLE)
-
-    if int(table.count_rows()) == 0:
-        logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
+    def _empty_summary(reason_label: str) -> None:
         _write_runtime_summary(
             runtime_metrics_dir,
             runtime_metrics_prefix,
@@ -612,8 +617,19 @@ def main(
                 "recall_details": bool(recall_details),
                 "lancedb_uri": str(lancedb_uri),
                 "lancedb_table": str(LANCEDB_TABLE),
+                "skip_reason": reason_label,
             },
+            metrics_output_file=metrics_output_file,
         )
+
+    import lancedb as _lancedb_mod
+
+    db = _lancedb_mod.connect(lancedb_uri)
+    table = db.open_table(LANCEDB_TABLE)
+
+    if int(table.count_rows()) == 0:
+        logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
+        _empty_summary("lancedb_table_empty")
         if run_mode == "batch":
             ray.shutdown()
         return
@@ -660,27 +676,7 @@ def main(
     query_csv_path = Path(query_csv)
     if not query_csv_path.exists():
         logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv_path)
-        _write_runtime_summary(
-            runtime_metrics_dir,
-            runtime_metrics_prefix,
-            {
-                "run_mode": run_mode,
-                "input_path": str(Path(input_path).resolve()),
-                "input_pages": int(num_rows),
-                "num_pages": int(num_rows),
-                "num_rows": int(num_rows),
-                "ingestion_only_secs": float(ingestion_only_total_time),
-                "ray_download_secs": float(ray_download_time),
-                "lancedb_write_secs": 0.0,
-                "evaluation_secs": 0.0,
-                "total_secs": float(time.perf_counter() - ingest_start),
-                "evaluation_mode": evaluation_mode,
-                "evaluation_metrics": {},
-                "recall_details": bool(recall_details),
-                "lancedb_uri": str(lancedb_uri),
-                "lancedb_table": str(LANCEDB_TABLE),
-            },
-        )
+        _empty_summary("query_csv_missing")
         if run_mode == "batch":
             ray.shutdown()
         return
@@ -731,6 +727,7 @@ def main(
             "lancedb_uri": str(lancedb_uri),
             "lancedb_table": str(LANCEDB_TABLE),
         },
+        metrics_output_file=metrics_output_file,
     )

     if run_mode == "batch":
diff --git a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py
index 401c4ba95..0059defec 100644
--- a/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py
+++ b/nemo_retriever/src/nemo_retriever/graph/ingestor_runtime.py
@@ -368,6 +368,7 @@ def _append_ordered_transform_stages(
     stage_order: tuple[str, ...],
     supports_dedup: bool,
     reshape_for_modal_content: bool,
+    vdb_upload_ops_out: list | None = None,
 ) -> Graph:
     """Append post-extraction transform stages in the exact recorded plan order."""

@@ -428,7 +429,10 @@ def _append_ordered_transform_stages(
             )
             graph = graph >> _BatchEmbedActor(params=embed_params)
         elif stage_name == "vdb_upload" and vdb_upload_params is not None:
-            graph = graph >> VDBUploadOperator(params=vdb_upload_params, vdb_op=vdb_op)
+            vdb_upload_op = VDBUploadOperator(params=vdb_upload_params, vdb_op=vdb_op)
+            if vdb_upload_ops_out is not None:
+                vdb_upload_ops_out.append(vdb_upload_op)
+            graph = graph >> vdb_upload_op

     return graph

@@ -450,6 +454,7 @@ def build_graph(
     vdb_upload_params: Any | None = None,
     vdb_op: Any | None = None,
     stage_order: tuple[str, ...] = (),
+    vdb_upload_ops_out: list | None = None,
 ) -> Graph:
     """Build a batch graph from explicit params or a shared execution plan."""

@@ -615,6 +620,7 @@ def build_graph(
         stage_order=stage_order,
         supports_dedup=True,
         reshape_for_modal_content=True,
+        vdb_upload_ops_out=vdb_upload_ops_out,
     )


diff --git a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py
index 60988132f..2a95b21c9 100644
--- a/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py
+++ b/nemo_retriever/src/nemo_retriever/graph/vdb_upload_operator.py
@@ -175,6 +175,16 @@ def process(self, data: Any, **kwargs: Any) -> Any:

     def postprocess(self, data: Any, **kwargs: Any) -> Any:
         return data

+    def finalize(self) -> None:
+        """Driver-side post-pipeline hook.
+
+        ``AbstractOperator.run()`` fires the preprocess/process/postprocess
+        lifecycle once per batch, so anything that must run exactly once at
+        the end of ingestion (flush, wait-for-index, etc.) is driven from
+        here by the driver after ``executor.ingest()`` returns.
+ """ + return + # ------------------------------------------------------------------ # LanceDB streaming write path # ------------------------------------------------------------------ diff --git a/nemo_retriever/src/nemo_retriever/graph_ingestor.py b/nemo_retriever/src/nemo_retriever/graph_ingestor.py index 290bb7704..d2c63ad5a 100644 --- a/nemo_retriever/src/nemo_retriever/graph_ingestor.py +++ b/nemo_retriever/src/nemo_retriever/graph_ingestor.py @@ -321,6 +321,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: ) cluster_resources = gather_cluster_resources(ray) + vdb_upload_ops: list[Any] = [] graph = build_graph( extraction_mode=self._extraction_mode, extract_params=self._extract_params, @@ -336,6 +337,7 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: vdb_upload_params=self._vdb_upload_params, vdb_op=self._vdb_op, stage_order=post_extract_order, + vdb_upload_ops_out=vdb_upload_ops, ) # Derive per-node Ray scheduling config from BatchTuningParams plus # cluster-scaled heuristic defaults, then let any explicit @@ -374,9 +376,12 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: ) result = executor.ingest(self._documents) self._rd_dataset = result + for op in vdb_upload_ops: + op.finalize() self._finalize_vdb() return result else: + vdb_upload_ops = [] graph = build_graph( extraction_mode=self._extraction_mode, extract_params=self._extract_params, @@ -392,10 +397,13 @@ def ingest(self, params: Any = None, **kwargs: Any) -> Any: vdb_upload_params=self._vdb_upload_params, vdb_op=self._vdb_op, stage_order=post_extract_order, + vdb_upload_ops_out=vdb_upload_ops, ) executor = InprocessExecutor(graph, show_progress=self._show_progress) self._rd_dataset = None result = executor.ingest(self._documents) + for op in vdb_upload_ops: + op.finalize() self._finalize_vdb() return result