Skip to content
2 changes: 1 addition & 1 deletion nemo_retriever/harness/test_configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ presets:

datasets:
bo20:
path: /home/jdyer/datasets/bo20
path: /datasets/nv-ingest/bo20
query_csv: null
input_type: pdf
recall_required: false
Expand Down
143 changes: 67 additions & 76 deletions nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@
from nemo_retriever.params import ExtractParams
from nemo_retriever.params import StoreParams
from nemo_retriever.params import TextChunkParams
from nemo_retriever.params import VdbUploadParams
from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL
from nemo_retriever.params.models import BatchTuningParams
from nemo_retriever.params.models import BatchTuningParams, LanceDbParams
from nemo_retriever.utils.input_files import resolve_input_patterns
from nemo_retriever.utils.remote_auth import resolve_remote_api_key
from nemo_retriever.vector_store.lancedb_store import handle_lancedb

logger = logging.getLogger(__name__)
app = typer.Typer()
Expand Down Expand Up @@ -126,36 +126,23 @@ def _configure_logging(log_file: Optional[Path], *, debug: bool = False) -> tupl
return fh, original_stdout, original_stderr


def _ensure_lancedb_table(uri: str, table_name: str) -> None:
    """Make sure the LanceDB table *table_name* exists under *uri*.

    The directory backing the database is created if needed. If the table
    can already be opened, nothing else happens; otherwise an empty table
    with the project's standard schema is created.
    """
    import lancedb
    import pyarrow as pa

    from nemo_retriever.vector_store.lancedb_utils import lancedb_schema

    # LanceDB expects the target directory to exist before connecting.
    Path(uri).mkdir(parents=True, exist_ok=True)
    db = lancedb.connect(uri)
    try:
        db.open_table(table_name)
    except Exception:
        # EAFP: any failure to open is treated as "table missing" — create
        # a fresh, empty table carrying the canonical schema.
        schema = lancedb_schema()
        placeholder = pa.table({field.name: [] for field in schema}, schema=schema)
        db.create_table(table_name, data=placeholder, schema=schema, mode="create")


def _write_runtime_summary(
runtime_metrics_dir: Optional[Path],
runtime_metrics_prefix: Optional[str],
payload: dict[str, object],
metrics_output_file: Optional[Path] = None,
) -> None:
if runtime_metrics_dir is None and not runtime_metrics_prefix:
return
if runtime_metrics_dir is not None or runtime_metrics_prefix:
target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")

target_dir = Path(runtime_metrics_dir or Path.cwd()).expanduser().resolve()
target_dir.mkdir(parents=True, exist_ok=True)
prefix = (runtime_metrics_prefix or "run").strip() or "run"
target = target_dir / f"{prefix}.runtime.summary.json"
target.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")
if metrics_output_file is not None:
out_path = Path(metrics_output_file).expanduser()
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_text(json.dumps(payload, indent=2, sort_keys=True) + "\n", encoding="utf-8")


def _count_input_units(result_df) -> int:
Expand Down Expand Up @@ -309,6 +296,13 @@ def main(
runtime_metrics_dir: Optional[Path] = typer.Option(None, "--runtime-metrics-dir", path_type=Path),
runtime_metrics_prefix: Optional[str] = typer.Option(None, "--runtime-metrics-prefix"),
detection_summary_file: Optional[Path] = typer.Option(None, "--detection-summary-file", path_type=Path),
metrics_output_file: Optional[Path] = typer.Option(
None,
"--metrics-output-file",
path_type=Path,
dir_okay=False,
help="JSON file path to write structured run metrics (used by the harness).",
),
log_file: Optional[Path] = typer.Option(None, "--log-file", path_type=Path, dir_okay=False),
) -> None:
_ = ctx
Expand All @@ -327,7 +321,6 @@ def main(
os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0"

lancedb_uri = str(Path(lancedb_uri).expanduser().resolve())
_ensure_lancedb_table(lancedb_uri, LANCEDB_TABLE)

remote_api_key = resolve_remote_api_key(api_key)
extract_remote_api_key = remote_api_key
Expand Down Expand Up @@ -528,39 +521,58 @@ def main(

ingestor = ingestor.embed(embed_params)

# VDB upload runs inside the graph — rows stream to the configured
# backend as they are produced, so we never need to collect the entire
# result set on the driver just for the write. Index creation happens
# automatically in GraphIngestor._finalize_vdb() after the pipeline.
ingestor = ingestor.vdb_upload(
VdbUploadParams(
lancedb=LanceDbParams(
lancedb_uri=lancedb_uri,
table_name=LANCEDB_TABLE,
hybrid=hybrid,
overwrite=True,
),
)
)

# ------------------------------------------------------------------
# Execute the graph via the executor
# ------------------------------------------------------------------
logger.info("Starting ingestion of %s ...", input_path)
ingest_start = time.perf_counter()

# GraphIngestor.ingest() builds the Graph, creates the executor,
# and calls executor.ingest(file_patterns) returning:
# calls executor.ingest(file_patterns), and finalizes the VDB index.
# batch mode -> materialized ray.data.Dataset
# inprocess mode -> pandas.DataFrame
result = ingestor.ingest()

ingestion_only_total_time = time.perf_counter() - ingest_start

# ------------------------------------------------------------------
# Collect results
# Collect results (only when needed for detection summary / counting)
# ------------------------------------------------------------------
if run_mode == "batch":
import ray

ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start
if detection_summary_file is not None:
ray_download_start = time.perf_counter()
ingest_local_results = result.take_all()
ray_download_time = time.perf_counter() - ray_download_start

import pandas as pd
import pandas as pd

result_df = pd.DataFrame(ingest_local_results)
num_rows = _count_input_units(result_df)
result_df = pd.DataFrame(ingest_local_results)
num_rows = _count_input_units(result_df)
else:
ray_download_time = 0.0
result_df = None
num_rows = result.count()
else:
import pandas as pd

result_df = result
ingest_local_results = result_df.to_dict("records")
ray_download_time = 0.0
num_rows = _count_input_units(result_df)

Expand All @@ -582,23 +594,10 @@ def main(
collect_detection_summary_from_df(result_df),
)

# ------------------------------------------------------------------
# Write to LanceDB
# ------------------------------------------------------------------
lancedb_write_start = time.perf_counter()
handle_lancedb(ingest_local_results, lancedb_uri, LANCEDB_TABLE, hybrid=hybrid, mode="overwrite")
lancedb_write_time = time.perf_counter() - lancedb_write_start

# ------------------------------------------------------------------
# Recall / BEIR evaluation
# ------------------------------------------------------------------
import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
table = db.open_table(LANCEDB_TABLE)

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
def _empty_summary(reason_label: str) -> None:
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
Expand All @@ -607,19 +606,30 @@ def main(
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"num_rows": int(num_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
"skip_reason": reason_label,
},
metrics_output_file=metrics_output_file,
)

import lancedb as _lancedb_mod

db = _lancedb_mod.connect(lancedb_uri)
table = db.open_table(LANCEDB_TABLE)

if int(table.count_rows()) == 0:
logger.warning("LanceDB table is empty; skipping %s evaluation.", evaluation_mode)
_empty_summary("lancedb_table_empty")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -666,27 +676,7 @@ def main(
query_csv_path = Path(query_csv)
if not query_csv_path.exists():
logger.warning("Query CSV not found at %s; skipping recall evaluation.", query_csv_path)
_write_runtime_summary(
runtime_metrics_dir,
runtime_metrics_prefix,
{
"run_mode": run_mode,
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"evaluation_secs": 0.0,
"total_secs": float(time.perf_counter() - ingest_start),
"evaluation_mode": evaluation_mode,
"evaluation_metrics": {},
"recall_details": bool(recall_details),
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
)
_empty_summary("query_csv_missing")
if run_mode == "batch":
ray.shutdown()
return
Expand Down Expand Up @@ -724,10 +714,10 @@ def main(
"input_path": str(Path(input_path).resolve()),
"input_pages": int(num_rows),
"num_pages": int(num_rows),
"num_rows": int(len(result_df.index)),
"num_rows": int(num_rows),
"ingestion_only_secs": float(ingestion_only_total_time),
"ray_download_secs": float(ray_download_time),
"lancedb_write_secs": float(lancedb_write_time),
"lancedb_write_secs": 0.0,
"evaluation_secs": float(evaluation_total_time),
"total_secs": float(total_time),
"evaluation_mode": evaluation_mode,
Expand All @@ -737,6 +727,7 @@ def main(
"lancedb_uri": str(lancedb_uri),
"lancedb_table": str(LANCEDB_TABLE),
},
metrics_output_file=metrics_output_file,
)

if run_mode == "batch":
Expand All @@ -751,7 +742,7 @@ def main(
total_time,
ingestion_only_total_time,
ray_download_time,
lancedb_write_time,
0.0,
evaluation_total_time,
evaluation_metrics,
evaluation_label=evaluation_label,
Expand Down
2 changes: 2 additions & 0 deletions nemo_retriever/src/nemo_retriever/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from nemo_retriever.graph.graph_pipeline_registry import GraphPipelineRegistry, default_registry
from nemo_retriever.graph.pipeline_graph import Graph, Node
from nemo_retriever.graph.store_operator import StoreOperator
from nemo_retriever.graph.vdb_upload_operator import VDBUploadOperator

__all__ = [
"AbstractExecutor",
Expand All @@ -32,6 +33,7 @@
"RayDataExecutor",
"StoreOperator",
"UDFOperator",
"VDBUploadOperator",
"default_registry",
]

Expand Down
Loading
Loading