From d9af7e4cf2cf51db53e68758ba5f745cb1b76480 Mon Sep 17 00:00:00 2001 From: Kyle Zheng Date: Wed, 15 Apr 2026 23:43:39 +0000 Subject: [PATCH 01/10] add support for full e2e qa run from one cli command --- .../src/nemo_retriever/evaluation/README.md | 150 +++++++++++++++++- .../src/nemo_retriever/evaluation/cli.py | 57 +++++-- .../src/nemo_retriever/evaluation/config.py | 8 +- .../nemo_retriever/evaluation/retrievers.py | 98 +++++++++++- .../nemo_retriever/examples/graph_pipeline.py | 96 ++++++++++- 5 files changed, 389 insertions(+), 20 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/README.md b/nemo_retriever/src/nemo_retriever/evaluation/README.md index 2926864a9..650f7733f 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/README.md +++ b/nemo_retriever/src/nemo_retriever/evaluation/README.md @@ -4,7 +4,7 @@ The evaluation framework lives in **`nemo_retriever.evaluation`** (install `nemo Measures LLM answer quality over a RAG pipeline: retrieve context from a VDB, generate answers with one or more LLMs, and score each answer against ground-truth references using multi-tier scoring and an LLM-as-judge. -**Pluggable retrieval:** The evaluation framework does not care how you retrieved chunks -- only that you produce a JSON file that matches the **[retrieval JSON specification](#retrieval-json-format-interface-contract)** expected by `retriever eval run` / `FileRetriever`. Vector search, hybrid, agentic pipelines, or any custom system can plug in as long as the file format and query strings align with your chosen ground-truth dataset. +**Pluggable retrieval:** The evaluation framework does not care how you retrieved chunks -- only that you produce a JSON file that matches the **[retrieval JSON specification](#retrieval-json-format-interface-contract)** expected by `retriever eval run` / `FileRetriever`. Vector search, hybrid, agentic pipelines, or any custom system can plug in as long as the file format and query strings align with your chosen ground-truth dataset. Alternatively, set `retrieval.type: "lancedb"` to **[query LanceDB live in-memory](#in-memory-lancedb-retrieval)** -- skipping the export step entirely while optionally saving the JSON for future re-runs. **Default ground truth:** Standalone runs default to **`data/bo767_annotations.csv`** at the repo root -- the **bo767 annotations subset** maintained for this benchmark (multi-modality Q&A over the bo767 PDFs). Override with `QA_DATASET` or another registered loader when comparing different corpora. 
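For orientation before the full [retrieval JSON specification](#retrieval-json-format-interface-contract), here is a minimal sketch of the file shape `FileRetriever` loads. Only the `queries` -> per-query `chunks` list is load-bearing; the `metadata` entries and the example query/source names below are illustrative, the exporter may write additional top-level keys, and each query key must match the ground-truth query string after `FileRetriever` normalization:

```json
{
  "queries": {
    "What is the warranty period for the X-200 pump?": {
      "chunks": ["Full-page markdown of the page that answers the question ..."],
      "metadata": [{"source": "example.pdf", "page": 3}]
    }
  }
}
```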
@@ -14,6 +14,7 @@ Designed to be **plug-and-play** -- swap retrievers, generators, or judges indep - [Pipeline File Map and Data Flow](#pipeline-file-map-and-data-flow) - [Reproducing the bo767 Run](#reproducing-the-bo767-run) + - [Which Path Should I Use?](#which-path-should-i-use) - [Retrieval JSON Format (Interface Contract)](#retrieval-json-format-interface-contract) - [Custom Datasets (CSV Loader)](#custom-datasets-csv-loader) - [Architecture](#architecture) @@ -26,6 +27,7 @@ Designed to be **plug-and-play** -- swap retrievers, generators, or judges indep - [Entry Points](#entry-points) - [Configuration](#configuration) - [Eval Config File (YAML / JSON)](#eval-config-file-yaml--json) + - [In-Memory LanceDB Retrieval](#in-memory-lancedb-retrieval) - [Scoring System (Three-Tier Hierarchy)](#scoring-system-three-tier-hierarchy) - [Adding a New Component](#adding-a-new-component) - [Output Format](#output-format) @@ -43,7 +45,7 @@ End-to-end bo767 + LanceDB + full-page markdown touches these **artifacts** and | **4. Ground truth** | `data/bo767_annotations.csv` (repo root) | Questions/answers for export and eval; must align with **query string normalization** in `FileRetriever` (see retrieval JSON rules). | | **5. Evaluation** | `qa_results_*.json` | `retriever eval run` or operator graph chain -> `nemo_retriever.evaluation`: `RetrievalLoaderOperator >> QAGenerationOperator >> JudgingOperator >> ScoringOperator`, or `QAEvalPipeline` for multi-model sweeps. | -**Data flow (conceptual):** PDFs -> (A) **chunked embeddings in LanceDB** for similarity search; (B) **Parquet** for full-page reconstruction. **Export** runs search on (A), then **replaces** hit chunks with pages from (B) via the index. **Eval** never talks to LanceDB -- it only reads the retrieval JSON + ground-truth CSV. +**Data flow (conceptual):** PDFs -> (A) **chunked embeddings in LanceDB** for similarity search; (B) **Parquet** for full-page reconstruction. **Export** runs search on (A), then **replaces** hit chunks with pages from (B) via the index. In **file mode**, eval reads the retrieval JSON + ground-truth CSV. In **[lancedb mode](#in-memory-lancedb-retrieval)**, eval queries LanceDB directly in-memory (optionally saving the JSON for later re-runs). ``` NeMo Retriever (steps 1-3) Universal (steps 4-5) @@ -81,11 +83,27 @@ Exact commands to reproduce the full-page markdown QA evaluation from scratch. **Debug:** Lance index build can hit `No space left on device` when `/tmp` is a tiny tmpfs; set `export TMPDIR=/path/to/large/filesystem/tmp` and `mkdir -p "$TMPDIR"` before step 1. If `extraction.parquet` was written but LanceDB failed, retry with `python -c "from nemo_retriever.utils.parquet_to_lancedb import reload_parquet_to_lancedb; reload_parquet_to_lancedb('', '')""`; otherwise re-run `graph_pipeline`. +**Before running any quick reference below**, complete the one-time setup: + +```bash +# 1. Create and activate a Python 3.12 environment +uv venv qa-retriever --python 3.12 +source qa-retriever/bin/activate + +# 2. Install nemo_retriever with eval extras (from repo root) +cd /path/to/nv-ingest +uv pip install -e "./nemo_retriever[eval]" + +# 3. Set your API key (used by generation + judging) +export NVIDIA_API_KEY="nvapi-..." +``` + +See [Python environment](#python-environment) and [Prerequisites](#prerequisites-data-and-keys) for details. +
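Before committing to a multi-hour run, a quick sanity check that the install and key are visible. This sketch assumes the standard Typer `--help` behaviour of the `retriever` CLI; the exact help text is not shown here:

```bash
# Confirm the eval subcommand, the litellm-backed eval modules, and the API key.
retriever eval run --help
python -c "import litellm, nemo_retriever.evaluation; print('eval imports OK')"
test -n "$NVIDIA_API_KEY" && echo "NVIDIA_API_KEY is set" || echo "NVIDIA_API_KEY is MISSING"
```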
Quick reference -- full-page markdown (all commands) ```bash -# All commands from repo root cd /path/to/nv-ingest # 1. Ingest + embed + save Parquet in one pass (~45-90 min) @@ -106,7 +124,6 @@ retriever eval export \ --page-index data/bo767_page_markdown.json # 4. Run QA evaluation (~1-2 hrs) -export NVIDIA_API_KEY="nvapi-..." retriever eval run --config nemo_retriever/examples/eval_sweep.yaml ``` @@ -134,12 +151,64 @@ retriever eval export \ --output data/eval/bo767_retrieval.json # 3. Run QA evaluation -export NVIDIA_API_KEY="nvapi-..." retriever eval run --config nemo_retriever/examples/eval_sweep.yaml ```
+
+Quick reference -- end-to-end in-memory (single command: ingest + eval) + +Alternative to the separable export+eval path above. One command ingests +PDFs, builds the full-page markdown index in-memory, queries LanceDB, and +runs generation + judging + scoring. Optionally saves the retrieval JSON +so you can re-run eval later without re-querying. + +```bash +cd /path/to/nv-ingest + +# Single command: ingest -> page index -> LanceDB query -> QA eval +python -m nemo_retriever.examples.graph_pipeline /path/to/bo767 \ + --lancedb-uri lancedb \ + --evaluation-mode qa \ + --eval-config nemo_retriever/examples/eval_sweep.yaml \ + --query-csv data/bo767_annotations.csv \ + --retrieval-save-path data/eval/bo767_retrieval.json +``` + +The page index is built automatically from the ingestion results (no +separate step needed). Pass `--page-index ` to use a pre-built +index instead, or `--save-intermediate ` if you also want the +extraction Parquet saved for other uses. + +Or, if you already have a LanceDB from a previous ingestion and want to +run eval standalone (no re-ingestion): + +```bash +export LANCEDB_URI="lancedb" +export QA_DATASET="csv:data/bo767_annotations.csv" +export RETRIEVAL_SAVE_PATH="data/eval/bo767_retrieval.json" # optional +retriever eval run --from-env +``` + +
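Once `RETRIEVAL_SAVE_PATH` has written the JSON, later re-runs can skip LanceDB and the query embedder entirely by switching to file mode. A sketch reusing the paths from the command above; note that `RETRIEVAL_FILE`, when set, takes precedence over `LANCEDB_URI`:

```bash
# Re-run the same evaluation from the saved JSON -- no LanceDB query this time.
export RETRIEVAL_FILE="data/eval/bo767_retrieval.json"
export QA_DATASET="csv:data/bo767_annotations.csv"
retriever eval run --from-env
```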
+ +### Which path should I use? + +| | **Option A: Separable (recommended)** | **Option B: End-to-end in-memory** | +|---|---|---| +| **Steps** | Ingest -> Export JSON -> Run eval | Single command: Ingest -> Page index -> LanceDB query -> Eval | +| **Re-run eval with a different model?** | Instant -- just re-run step 4 with the same JSON | Must re-query LanceDB (or pass `save_path` to cache) | +| **Share results with teammates?** | Send the retrieval JSON file | Must share LanceDB or use `save_path` | +| **Best for** | Benchmarking, CI, reproducible comparisons | Quick end-to-end validation, development iteration | +| **Commands** | `retriever eval export` + `retriever eval run` | `--evaluation-mode qa` on `graph_pipeline` or `LANCEDB_URI` with `--from-env` | + +Option A is the primary design -- the retrieval JSON is the interface contract +that lets you re-run eval N times without re-querying, swap retrievers, and +share results. Option B is a convenience shortcut for when you want a single +end-to-end pass (e.g. validating a new ingestion pipeline). Both produce +identical evaluation results. + ### Bring your own retrieval (skip steps 1-3) Steps 1-3 below are the **NeMo Retriever + LanceDB** reference implementation @@ -245,6 +314,9 @@ Output: `data/bo767_page_markdown.json` (~180 MB, ~6k pages across 767 docs). ### Step 3: Export retrieval results (NeMo Retriever) +> **Skip this step** if using the [end-to-end in-memory path](#which-path-should-i-use) +> (`--evaluation-mode qa`), which queries LanceDB live instead of reading a JSON file. + Queries LanceDB for each ground-truth question via `nemo_retriever.export.export_retrieval_json()`, then looks up the full-page markdown for each hit's page. Multiple sub-page hits from the same page are deduplicated into a single full-page chunk. @@ -276,6 +348,11 @@ only step 1 without `--save-intermediate` (no Parquet or page index needed). ### Step 4: Run QA evaluation +> **Alternative:** skip steps 3 and 4 and use `--evaluation-mode qa` on +> `graph_pipeline` to query LanceDB in-memory and run eval in one pass. +> See [Which path should I use?](#which-path-should-i-use) and the +> end-to-end quick reference above. + **Estimated time: ~15 min - 45 min** (1005 queries, ~12s per query for generation + judge, 8 concurrent workers). ```bash @@ -709,6 +786,65 @@ LiteLLM routes by prefix: For local vLLM/Ollama, use `openai/` with `api_base: http://localhost:8000/v1`. +### In-Memory LanceDB Retrieval + +Instead of pre-exporting a retrieval JSON, the eval pipeline can query LanceDB +live and feed results directly to generation/judging. 
Set `retrieval.type` to +`"lancedb"` in your config: + +```yaml +retrieval: + type: "lancedb" + lancedb_uri: "lancedb" + lancedb_table: "nv-ingest" + embedder: "nvidia/llama-nemotron-embed-1b-v2" + save_path: "data/eval/bo767_retrieval.json" # optional -- persists for re-runs + page_index: "data/bo767_page_markdown.json" # optional -- full-page expansion +``` + +| Field | Default | Description | +|-------|---------|-------------| +| `type` | `"file"` | `"file"` reads a pre-exported JSON; `"lancedb"` queries live | +| `lancedb_uri` | `"lancedb"` | LanceDB directory path | +| `lancedb_table` | `"nv-ingest"` | Table name inside the LanceDB directory | +| `embedder` | `"nvidia/llama-nemotron-embed-1b-v2"` | Embedding model for query encoding | +| `save_path` | *none* | If set, writes the retrieval JSON for later `type: "file"` re-runs | +| `page_index` | *none* | JSON page-markdown index for full-page chunk expansion | + +Three usage modes from a single config: + +- **End-to-end + save**: `type: lancedb, save_path: retrieval.json` -- queries live, saves for future, runs eval +- **End-to-end no save**: `type: lancedb` -- queries live, pure in-memory, runs eval +- **File-only (existing)**: `type: file, file_path: retrieval.json` -- unchanged, reads saved file + +**CLI (--from-env mode):** set `LANCEDB_URI` instead of `RETRIEVAL_FILE`: + +```bash +export LANCEDB_URI="lancedb" +export LANCEDB_TABLE="nv-ingest" # optional, default: nv-ingest +export EMBEDDER="nvidia/llama-nemotron-embed-1b-v2" # optional +export RETRIEVAL_SAVE_PATH="data/eval/retrieval.json" # optional +export QA_DATASET="csv:data/bo767_annotations.csv" +retriever eval run --from-env +``` + +**Graph pipeline (single command: ingest + page index + QA eval):** + +The page index is built automatically from the ingestion results -- no +separate `build-page-index` step is needed. + +```bash +python -m nemo_retriever.examples.graph_pipeline /data/pdfs \ + --lancedb-uri lancedb \ + --evaluation-mode qa \ + --eval-config nemo_retriever/examples/eval_sweep.yaml \ + --query-csv data/bo767_annotations.csv \ + --retrieval-save-path data/eval/bo767_retrieval.json +``` + +Pass `--page-index ` to use a pre-built index instead of the +automatic in-memory build. 
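The same in-memory path is available programmatically -- this is essentially what the CLI and `graph_pipeline` do internally. A minimal sketch for notebook or script use; the paths are illustrative, the ground-truth CSV is assumed to carry a `query` column, and `page_index=` can be passed for full-page markdown expansion as described in the field table above:

```python
import csv

from nemo_retriever.evaluation.config import load_eval_config
from nemo_retriever.evaluation.retrievers import FileRetriever
from nemo_retriever.evaluation.runner import run_eval_sweep

# Ground-truth rows; each dict must carry a non-empty "query" value.
with open("data/bo767_annotations.csv", newline="", encoding="utf-8") as f:
    qa_pairs = [row for row in csv.DictReader(f) if row.get("query", "").strip()]

config = load_eval_config("nemo_retriever/examples/eval_sweep.yaml")

# Query LanceDB in-memory; save_path optionally persists the JSON for file-mode re-runs.
retriever = FileRetriever.from_lancedb(
    qa_pairs=qa_pairs,
    lancedb_uri="lancedb",
    lancedb_table="nv-ingest",
    embedder="nvidia/llama-nemotron-embed-1b-v2",
    top_k=config.get("execution", {}).get("top_k", 5),
    save_path="data/eval/bo767_retrieval.json",
)
print(f"coverage: {retriever.check_coverage(qa_pairs):.1%}")

results = run_eval_sweep(config, qa_pairs, "data/qa_results", retriever=retriever)
```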
+ ### Environment Variables | Variable | Used By | Purpose | @@ -717,6 +853,10 @@ For local vLLM/Ollama, use `openai/` with `api_base: http://localhost:800 | `NVIDIA_NIM_API_KEY` | litellm's `nvidia_nim` provider | Alias -- set to same value as `NVIDIA_API_KEY` | | `GEN_API_KEY` | `retriever eval run --from-env`, config `${GEN_API_KEY}` | Generator API key (falls back to `NVIDIA_API_KEY`) | | `JUDGE_API_KEY` | `retriever eval run --from-env`, config `${JUDGE_API_KEY}` | Judge API key (falls back to `NVIDIA_API_KEY`) | +| `LANCEDB_URI` | `retriever eval run --from-env` (lancedb mode) | LanceDB directory path (activates lancedb mode when `RETRIEVAL_FILE` is unset) | +| `LANCEDB_TABLE` | `retriever eval run --from-env` (lancedb mode) | LanceDB table name (default: `nv-ingest`) | +| `EMBEDDER` | `retriever eval run --from-env` (lancedb mode) | Embedding model name | +| `RETRIEVAL_SAVE_PATH` | `retriever eval run --from-env` (lancedb mode) | Optional path to persist the retrieval JSON | ## Scoring System (Three-Tier Hierarchy) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/cli.py b/nemo_retriever/src/nemo_retriever/evaluation/cli.py index c434ca41e..543c87823 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/cli.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/cli.py @@ -126,8 +126,12 @@ def _build_env_config() -> tuple[dict, str, str, str, float]: Returns ``(config, qa_dataset, ground_truth_dir, results_dir, min_coverage)``. """ retrieval_file = os.environ.get("RETRIEVAL_FILE", "") - if not retrieval_file: - typer.echo("ERROR: RETRIEVAL_FILE environment variable is required with --from-env", err=True) + lancedb_uri = os.environ.get("LANCEDB_URI", "") + if not retrieval_file and not lancedb_uri: + typer.echo( + "ERROR: set RETRIEVAL_FILE (file mode) or LANCEDB_URI (lancedb mode) with --from-env", + err=True, + ) raise typer.Exit(code=1) qa_dataset = os.environ.get("QA_DATASET", "") @@ -187,6 +191,17 @@ def _build_env_config() -> tuple[dict, str, str, str, float]: os.path.dirname(os.environ.get("OUTPUT_FILE", "")) or "data/test_retrieval", ) + if retrieval_file: + retrieval_block: dict[str, str | None] = {"type": "file", "file_path": retrieval_file} + else: + retrieval_block = { + "type": "lancedb", + "lancedb_uri": lancedb_uri, + "lancedb_table": os.environ.get("LANCEDB_TABLE", "nv-ingest"), + "embedder": os.environ.get("EMBEDDER", "nvidia/llama-nemotron-embed-1b-v2"), + "save_path": os.environ.get("RETRIEVAL_SAVE_PATH"), + } + config = { "execution": { "runs": 1, @@ -196,7 +211,7 @@ def _build_env_config() -> tuple[dict, str, str, str, float]: "min_coverage": min_coverage, }, "dataset": {"source": qa_dataset, "ground_truth_dir": ground_truth_dir}, - "retrieval": {"file_path": retrieval_file}, + "retrieval": retrieval_block, "models": models, "evaluations": evaluations, "output": {"results_dir": results_dir}, @@ -258,11 +273,6 @@ def run_cmd( typer.echo("ERROR: dataset.source is required in config", err=True) raise typer.Exit(code=1) - retrieval_file = retrieval_cfg.get("file_path", "") - if not retrieval_file: - typer.echo("ERROR: retrieval.file_path is required in config", err=True) - raise typer.Exit(code=1) - results_dir = output_cfg.get("results_dir", "data/test_retrieval") qa_limit = execution.get("limit", 0) min_coverage = execution.get("min_coverage", 0.0) @@ -276,13 +286,38 @@ def run_cmd( qa_pairs = qa_pairs[:qa_limit] typer.echo(f"limit={qa_limit}: evaluating first {len(qa_pairs)} pairs") - retriever = FileRetriever(file_path=retrieval_file) + retrieval_type = 
retrieval_cfg.get("type", "file") + if retrieval_type == "lancedb": + page_index_path = retrieval_cfg.get("page_index") + page_idx = None + if page_index_path: + with open(page_index_path, encoding="utf-8") as f: + page_idx = json.load(f) + typer.echo(f"Loaded page index: {len(page_idx)} documents") + + save_path = retrieval_cfg.get("save_path") + retriever = FileRetriever.from_lancedb( + qa_pairs=qa_pairs, + lancedb_uri=retrieval_cfg.get("lancedb_uri", "lancedb"), + lancedb_table=retrieval_cfg.get("lancedb_table", "nv-ingest"), + embedder=retrieval_cfg.get("embedder", "nvidia/llama-nemotron-embed-1b-v2"), + top_k=execution.get("top_k", 5), + page_index=page_idx, + save_path=save_path, + ) + typer.echo("Built retriever from LanceDB (in-memory)") + else: + retrieval_file = retrieval_cfg.get("file_path", "") + if not retrieval_file: + typer.echo("ERROR: retrieval.file_path is required when type='file'", err=True) + raise typer.Exit(code=1) + retriever = FileRetriever(file_path=retrieval_file) + coverage = retriever.check_coverage(qa_pairs) typer.echo(f"Coverage: {coverage:.1%}") if coverage < min_coverage: typer.echo( - f"ERROR: retrieval file covers only {coverage:.1%} of queries " - f"(min_coverage={min_coverage:.0%}). Aborting.", + f"ERROR: retrieval covers only {coverage:.1%} of queries " f"(min_coverage={min_coverage:.0%}). Aborting.", err=True, ) raise typer.Exit(code=1) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/config.py b/nemo_retriever/src/nemo_retriever/evaluation/config.py index 1b5c7ce3a..ab8d8de98 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/config.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/config.py @@ -345,8 +345,14 @@ def build_eval_pipeline(config: dict) -> "QAEvalPipeline": retrieval_type = retrieval.get("type", "file") if retrieval_type == "file": retriever = FileRetriever(file_path=retrieval["file_path"]) + elif retrieval_type == "lancedb": + raise ValueError( + "retrieval.type='lancedb' requires the caller to build the retriever " + "via FileRetriever.from_lancedb() and pass it to run_eval_sweep(retriever=...). " + "build_eval_pipeline() cannot construct it because it needs qa_pairs." + ) else: - raise ValueError(f"Unsupported retrieval type: {retrieval_type!r}. " "Currently only 'file' is supported.") + raise ValueError(f"Unsupported retrieval type: {retrieval_type!r}. " "Supported: 'file', 'lancedb'.") llm_clients: dict[str, LiteLLMClient] = {} for gen_cfg in generators: diff --git a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py index 0b4ad0d61..c2d09a387 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py @@ -5,11 +5,13 @@ """ Retriever strategy implementations for the QA evaluation pipeline. -FileRetriever: reads pre-computed retrieval results from a JSON file. +FileRetriever: reads pre-computed retrieval results from a JSON file, +or queries LanceDB in-memory via ``from_lancedb()``. FileRetriever is the primary integration point. Any retrieval method -- vector search, agentic retrieval, hybrid, reranked, BM25, or a completely custom -pipeline -- can plug into the QA eval harness by writing a single JSON file. +pipeline -- can plug into the QA eval harness by writing a single JSON file +or by using ``FileRetriever.from_lancedb()`` to query a live vector DB. 
""" from __future__ import annotations @@ -20,6 +22,7 @@ import re import threading import unicodedata + from nemo_retriever.evaluation.types import RetrievalResult logger = logging.getLogger(__name__) @@ -86,6 +89,97 @@ def __init__(self, file_path: str): self._miss_count = 0 self._miss_lock = threading.Lock() + @classmethod + def _from_dict(cls, queries: dict[str, dict]) -> "FileRetriever": + """Build a FileRetriever from an in-memory queries dict. + + Bypasses file I/O while reusing the same normalized index that + ``__init__`` builds from JSON. All instance methods (``retrieve``, + ``check_coverage``) work identically afterwards. + + Parameters + ---------- + queries : dict + ``{query_text: {"chunks": [...], "metadata": [...]}}`` -- + the same shape as the ``"queries"`` value in a retrieval JSON. + """ + if not queries: + raise ValueError("FileRetriever._from_dict: queries dict is empty") + sample = next(iter(queries.values()), {}) + if not isinstance(sample.get("chunks"), list): + raise ValueError( + "FileRetriever._from_dict: first entry is missing a 'chunks' list. " + 'Expected: {"query": {"chunks": ["..."]}}' + ) + + instance = object.__new__(cls) + instance.file_path = "" + instance._norm_index = {} + instance._raw_keys = {} + instance._miss_count = 0 + instance._miss_lock = threading.Lock() + for raw_key, value in queries.items(): + norm = _normalize_query(raw_key) + instance._norm_index[norm] = value + instance._raw_keys[norm] = raw_key + return instance + + @classmethod + def from_lancedb( + cls, + qa_pairs: list[dict], + lancedb_uri: str = "lancedb", + lancedb_table: str = "nv-ingest", + embedder: str = "nvidia/llama-nemotron-embed-1b-v2", + top_k: int = 5, + page_index: dict[str, dict[str, str]] | None = None, + save_path: str | None = None, + ) -> "FileRetriever": + """Query LanceDB in-memory, optionally save, return a FileRetriever. + + Reuses :func:`~nemo_retriever.export.query_lancedb` for batched + vector search and :func:`~nemo_retriever.export.write_retrieval_json` + for optional disk persistence. + + Parameters + ---------- + qa_pairs : list[dict] + Ground-truth pairs; each must have a ``"query"`` key. + lancedb_uri : str + Path to the LanceDB directory. + lancedb_table : str + LanceDB table name. + embedder : str + Embedding model name for query encoding. + top_k : int + Number of chunks to retrieve per query. + page_index : dict, optional + ``{source_id: {page_str: markdown}}``. Enables full-page + markdown expansion when provided. + save_path : str, optional + If set, also writes the retrieval JSON to this path so it + can be reloaded later via ``FileRetriever(file_path=...)``. 
+ """ + from nemo_retriever.export import query_lancedb, write_retrieval_json + + all_results, meta = query_lancedb( + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + queries=qa_pairs, + top_k=top_k, + embedder=embedder, + page_index=page_index, + ) + + if save_path: + write_retrieval_json(all_results, save_path, meta) + logger.info("Saved retrieval JSON to %s", save_path) + + instance = cls._from_dict(all_results) + if save_path: + instance.file_path = save_path + return instance + def check_coverage(self, qa_pairs: list[dict]) -> float: """Validate retrieval file covers the ground-truth queries.""" total = len(qa_pairs) diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index 3e1091778..e3e890c49 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -297,6 +297,27 @@ def main( audio_split_type: str = typer.Option("size", "--audio-split-type"), audio_split_interval: int = typer.Option(500000, "--audio-split-interval", min=1), evaluation_mode: str = typer.Option("recall", "--evaluation-mode"), + eval_config: Optional[Path] = typer.Option( + None, + "--eval-config", + help="Path to QA eval sweep YAML/JSON config (required for --evaluation-mode=qa).", + path_type=Path, + dir_okay=False, + ), + retrieval_save_path: Optional[Path] = typer.Option( + None, + "--retrieval-save-path", + help="Save the LanceDB retrieval JSON here for later re-runs (--evaluation-mode=qa).", + path_type=Path, + dir_okay=False, + ), + page_index: Optional[Path] = typer.Option( + None, + "--page-index", + help="Page markdown index JSON for full-page chunk expansion (--evaluation-mode=qa).", + path_type=Path, + dir_okay=False, + ), reranker: Optional[bool] = typer.Option(False, "--reranker/--no-reranker"), reranker_model_name: str = typer.Option(VL_RERANK_MODEL, "--reranker-model-name"), beir_loader: Optional[str] = typer.Option(None, "--beir-loader"), @@ -320,8 +341,10 @@ def main( raise ValueError(f"Unsupported --recall-match-mode: {recall_match_mode!r}") if audio_split_type not in {"size", "time", "frame"}: raise ValueError(f"Unsupported --audio-split-type: {audio_split_type!r}") - if evaluation_mode not in {"recall", "beir"}: + if evaluation_mode not in {"recall", "beir", "qa"}: raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") + if evaluation_mode == "qa" and not eval_config: + raise ValueError("--eval-config is required when --evaluation-mode=qa") if run_mode == "batch": os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0" @@ -662,6 +685,77 @@ def main( evaluation_total_time = time.perf_counter() - evaluation_start evaluation_label = "BEIR" evaluation_query_count = len(beir_dataset.query_ids) + elif evaluation_mode == "qa": + import csv as csv_mod + + from nemo_retriever.evaluation.config import load_eval_config + from nemo_retriever.evaluation.retrievers import FileRetriever + from nemo_retriever.evaluation.runner import run_eval_sweep + + qa_csv_path = Path(query_csv) + if not qa_csv_path.exists(): + raise FileNotFoundError(f"Query CSV not found: {qa_csv_path}") + + qa_pairs: list[dict] = [] + with open(qa_csv_path, newline="", encoding="utf-8") as f: + for row in csv_mod.DictReader(f): + q = row.get("query", "").strip() + if q: + qa_pairs.append(row) + logger.info("Loaded %d Q&A pairs from %s", len(qa_pairs), qa_csv_path) + + page_idx = None + if page_index is not None: + with 
open(page_index, encoding="utf-8") as f: + page_idx = json.load(f) + logger.info("Loaded page index from file: %d documents", len(page_idx)) + else: + from nemo_retriever.io.markdown import build_page_index + + logger.info("Building page index from ingestion results (%d rows)...", len(result_df)) + page_idx, page_failures = build_page_index(dataframe=result_df) + if page_failures: + logger.warning("Page index: %d documents failed rendering", len(page_failures)) + if not page_idx: + logger.warning( + "Page index is empty -- all documents failed rendering. " + "Retrieval will fall back to sub-page chunks." + ) + else: + logger.info("Built page index: %d documents", len(page_idx)) + + qa_cfg = load_eval_config(str(eval_config)) + qa_top_k = qa_cfg.get("execution", {}).get("top_k", 5) + results_dir = qa_cfg.get("output", {}).get("results_dir", "data/qa_results") + + evaluation_start = time.perf_counter() + retriever = FileRetriever.from_lancedb( + qa_pairs=qa_pairs, + lancedb_uri=str(lancedb_uri), + lancedb_table=str(LANCEDB_TABLE), + embedder=_recall_model, + top_k=qa_top_k, + page_index=page_idx, + save_path=str(retrieval_save_path) if retrieval_save_path else None, + ) + coverage = retriever.check_coverage(qa_pairs) + logger.info("Retrieval coverage: %.1f%%", coverage * 100) + sweep_results = run_eval_sweep(qa_cfg, qa_pairs, results_dir, retriever=retriever) + + evaluation_total_time = time.perf_counter() - evaluation_start + evaluation_label = "QA" + evaluation_query_count = len(qa_pairs) + + passed = sum(1 for r in sweep_results if r["status"] == "PASS") + logger.info("QA sweep complete: %d/%d passed", passed, len(sweep_results)) + for r in sweep_results: + if r["status"] == "PASS": + out = Path(r["output_path"]).resolve() + logger.info("Results: %s", out) + er = r.get("eval_results", {}) + judge_scores = er.get("tier3_llm_judge", {}) + for gen_name, stats in judge_scores.items(): + evaluation_metrics[f"{gen_name} ({r['label']})"] = float(stats.get("mean_score", 0.0)) else: query_csv_path = Path(query_csv) if not query_csv_path.exists(): From 31d4e88f43e5bc795a6ee9e5aa78fccd45e537bf Mon Sep 17 00:00:00 2001 From: Kyle Zheng Date: Thu, 16 Apr 2026 20:56:14 +0000 Subject: [PATCH 02/10] Enforce min_coverage threshold in graph_pipeline QA path The eval CLI (cli.py) already aborts when retrieval coverage falls below the configured min_coverage. The graph_pipeline QA path logged coverage but never enforced the threshold, creating inconsistent behavior between the two entry points. 
Made-with: Cursor --- .../src/nemo_retriever/examples/graph_pipeline.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index e3e890c49..7940acc4b 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -725,7 +725,9 @@ def main( logger.info("Built page index: %d documents", len(page_idx)) qa_cfg = load_eval_config(str(eval_config)) - qa_top_k = qa_cfg.get("execution", {}).get("top_k", 5) + execution_cfg = qa_cfg.get("execution", {}) + qa_top_k = execution_cfg.get("top_k", 5) + min_coverage = float(execution_cfg.get("min_coverage", 0.0)) results_dir = qa_cfg.get("output", {}).get("results_dir", "data/qa_results") evaluation_start = time.perf_counter() @@ -740,6 +742,10 @@ def main( ) coverage = retriever.check_coverage(qa_pairs) logger.info("Retrieval coverage: %.1f%%", coverage * 100) + if coverage < min_coverage: + raise ValueError( + f"Retrieval covers only {coverage:.1%} of queries " f"(min_coverage={min_coverage:.0%}). Aborting." + ) sweep_results = run_eval_sweep(qa_cfg, qa_pairs, results_dir, retriever=retriever) evaluation_total_time = time.perf_counter() - evaluation_start From 632cd365345bdf0333e2746c1fc50c0f6b380337 Mon Sep 17 00:00:00 2001 From: Kyle Zheng Date: Thu, 16 Apr 2026 21:12:31 +0000 Subject: [PATCH 03/10] Add pre-flight validation and lancedb guard for QA eval paths - build_eval_chain() now raises a clear ValueError when retrieval.type is not 'file', matching build_eval_pipeline(). - graph_pipeline validates eval_config and query_csv file existence before starting ingestion so a typo does not waste a full run. Made-with: Cursor --- nemo_retriever/src/nemo_retriever/evaluation/config.py | 6 ++++++ .../src/nemo_retriever/examples/graph_pipeline.py | 5 +++++ 2 files changed, 11 insertions(+) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/config.py b/nemo_retriever/src/nemo_retriever/evaluation/config.py index ab8d8de98..a6e9289ce 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/config.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/config.py @@ -272,6 +272,12 @@ def build_eval_chain( dataset = config.get("dataset", {}) judge_cfg = config["judge"] + retrieval_type = retrieval.get("type", "file") + if retrieval_type != "file": + raise ValueError( + f"build_eval_chain() only supports retrieval.type='file', got {retrieval_type!r}. " + "For LanceDB retrieval, use FileRetriever.from_lancedb() with run_eval_sweep()." 
+ ) retrieval_json = retrieval.get("file_path", "") ground_truth_source = dataset.get("source", "") diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index 7940acc4b..c576c7ffd 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -345,6 +345,11 @@ def main( raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") if evaluation_mode == "qa" and not eval_config: raise ValueError("--eval-config is required when --evaluation-mode=qa") + if evaluation_mode == "qa": + if not Path(str(eval_config)).exists(): + raise FileNotFoundError(f"--eval-config file not found: {eval_config}") + if not Path(str(query_csv)).exists(): + raise FileNotFoundError(f"--query-csv file not found: {query_csv}") if run_mode == "batch": os.environ["RAY_LOG_TO_DRIVER"] = "1" if ray_log_to_driver else "0" From b04e4ab1da9e553cdb1dc81b2d9dd6151af6198e Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Tue, 21 Apr 2026 03:28:58 +0000 Subject: [PATCH 04/10] add live rag sdk support --- nemo_retriever/README.md | 148 +++++ nemo_retriever/pyproject.toml | 8 +- .../src/nemo_retriever/evaluation/README.md | 68 ++- .../src/nemo_retriever/evaluation/__init__.py | 10 +- .../src/nemo_retriever/evaluation/cli.py | 2 +- .../src/nemo_retriever/evaluation/config.py | 14 +- .../nemo_retriever/evaluation/generation.py | 6 +- .../src/nemo_retriever/evaluation/judging.py | 6 +- .../evaluation/live_retrieval.py | 92 +++ .../nemo_retriever/evaluation/orchestrator.py | 2 +- .../nemo_retriever/evaluation/retrievers.py | 2 +- .../src/nemo_retriever/evaluation/runner.py | 11 +- .../src/nemo_retriever/evaluation/scoring.py | 2 +- .../nemo_retriever/evaluation/text_utils.py | 22 - .../src/nemo_retriever/evaluation/types.py | 67 --- .../src/nemo_retriever/llm/__init__.py | 80 +++ .../nemo_retriever/llm/clients/__init__.py | 64 ++ .../judges.py => llm/clients/judge.py} | 60 +- .../generators.py => llm/clients/litellm.py} | 103 +++- .../src/nemo_retriever/llm/text_utils.py | 28 + .../src/nemo_retriever/llm/types.py | 117 ++++ .../src/nemo_retriever/params/__init__.py | 2 + .../src/nemo_retriever/params/models.py | 67 ++- .../src/nemo_retriever/retriever.py | 532 ++++++++++++++++- nemo_retriever/tests/test_live_rag.py | 558 ++++++++++++++++++ nemo_retriever/tests/test_llm_params.py | 395 +++++++++++++ nemo_retriever/uv.lock | 10 +- 27 files changed, 2293 insertions(+), 183 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/evaluation/live_retrieval.py delete mode 100644 nemo_retriever/src/nemo_retriever/evaluation/text_utils.py delete mode 100644 nemo_retriever/src/nemo_retriever/evaluation/types.py create mode 100644 nemo_retriever/src/nemo_retriever/llm/__init__.py create mode 100644 nemo_retriever/src/nemo_retriever/llm/clients/__init__.py rename nemo_retriever/src/nemo_retriever/{evaluation/judges.py => llm/clients/judge.py} (71%) rename nemo_retriever/src/nemo_retriever/{evaluation/generators.py => llm/clients/litellm.py} (59%) create mode 100644 nemo_retriever/src/nemo_retriever/llm/text_utils.py create mode 100644 nemo_retriever/src/nemo_retriever/llm/types.py create mode 100644 nemo_retriever/tests/test_live_rag.py create mode 100644 nemo_retriever/tests/test_llm_params.py diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 18e9a1927..d5ae84359 100644 --- 
a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -98,6 +98,51 @@ ray_dataset = ingestor.ingest() chunks = ray_dataset.get_dataset().take_all() ``` +### Ingest a test corpus (CLI) + +`graph_pipeline` is the canonical ingestion script used throughout the +[QA evaluation guide](./src/nemo_retriever/evaluation/README.md#step-1-ingest-and-embed-pdfs-nemo-retriever). +Point it at a **directory** of PDFs to produce a ready-to-query LanceDB table. + +> **Corpus size matters.** LanceDB's default IVF index needs at least 16 +> chunks to train its 16 k-means partitions. Single-PDF ingestion will fail +> at the indexing step; point `graph_pipeline` at a directory with enough +> documents to clear that threshold. Replace `/your-example-dir` below with +> the path to your own corpus. + +```bash +python -m nemo_retriever.examples.graph_pipeline \ + /your-example-dir \ + --lancedb-uri lancedb +``` + +Chunks land at `./lancedb/nv-ingest`, which matches the default `Retriever()` +constructor used in [Run a recall query](#run-a-recall-query) below. With the +`[local]` extra installed (see setup), defaults point at local-GPU extraction +and embedding. For a realistic retrieval corpus, see +[QA evaluation -- Step 1](./src/nemo_retriever/evaluation/README.md#step-1-ingest-and-embed-pdfs-nemo-retriever). + +**No local GPU?** Set `NVIDIA_API_KEY` and route extraction and embedding +through [build.nvidia.com](https://build.nvidia.com/) NIMs instead: + +```bash +export NVIDIA_API_KEY=nvapi-... + +python -m nemo_retriever.examples.graph_pipeline \ + /your-example-dir \ + --lancedb-uri lancedb \ + --page-elements-invoke-url https://ai.api.nvidia.com/v1/cv/nvidia/nemotron-page-elements-v3 \ + --graphic-elements-invoke-url https://ai.api.nvidia.com/v1/cv/nvidia/nemotron-graphic-elements-v1 \ + --ocr-invoke-url https://ai.api.nvidia.com/v1/cv/nvidia/nemoretriever-ocr-v1 \ + --table-structure-invoke-url https://ai.api.nvidia.com/v1/cv/nvidia/nemotron-table-structure-v1 \ + --embed-invoke-url https://integrate.api.nvidia.com/v1/embeddings \ + --embed-model-name nvidia/llama-nemotron-embed-1b-v2 +``` + +When you use the remote embedder, pair the `Retriever` with the matching +`embedder=` + `embedding_endpoint=` overrides shown in +[Run a recall query](#run-a-recall-query). + ### Inspect extracts You can inspect how recall accuracy optimized text chunks for various content types were extracted into text representations: ```python @@ -151,6 +196,22 @@ query = "Given their activities, which animal is responsible for the typos in my hits = retriever.query(query) ``` +If you ingested with the remote-NIM recipe above (no local GPU), point the +`Retriever` at the same embedding endpoint so query vectors are produced by the +same model that produced the stored chunk vectors: + +```python +retriever = Retriever( + lancedb_uri="lancedb", + lancedb_table="nv-ingest", + embedder="nvidia/llama-nemotron-embed-1b-v2", + embedding_endpoint="https://integrate.api.nvidia.com/v1/embeddings", + top_k=5, + reranker=False, +) +hits = retriever.query(query) +``` + ```python # retrieved text from the first page >>> hits[0] @@ -202,6 +263,93 @@ Answer: Cat is the animal whose activity (jumping onto a laptop) matches the location of the typos, so the cat is responsible for the typos in the documents. ``` +### Live RAG SDK (retrieve + answer in one call) + +The pattern above -- retrieve hits, build a prompt, call an LLM -- is baked into the SDK as `Retriever.answer()` so live applications can skip the boilerplate. 
The same `Retriever` instance powers three entry points: + +| Method | Input | Output | Use case | +| --- | --- | --- | --- | +| `Retriever.retrieve(query, top_k=...)` | one query | `RetrievalResult` (`chunks`, `metadata`) | Structured retrieval without an LLM. | +| `Retriever.answer(query, llm=..., judge=None, reference=None, ...)` | one query | `AnswerResult` (answer + chunks + optional scores) | One-shot RAG -- production/live. | +| `Retriever.pipeline().generate(...).score().judge(...).run(queries)` | many queries | `pandas.DataFrame` | Batch RAG over the operator graph, each step optional. | + +Install the LLM client extra: +```bash +uv pip install "nemo-retriever[llm]" +export NVIDIA_API_KEY=nvapi-... +``` + +Single-query live RAG. Point `lancedb_uri` at any table built above; the +`embedder` must match the one used during ingestion so query vectors land in +the same embedding space as the stored chunks. + +```python +from nemo_retriever.retriever import Retriever +from nemo_retriever.llm import LiteLLMClient + +retriever = Retriever( + lancedb_uri="lancedb", + lancedb_table="nv-ingest", + embedder="nvidia/llama-nemotron-embed-1b-v2", + embedding_endpoint="https://integrate.api.nvidia.com/v1/embeddings", + top_k=5, +) +llm = LiteLLMClient.from_kwargs( + model="nvidia_nim/nvidia/llama-3.3-nemotron-super-49b-v1.5", + temperature=0.0, + max_tokens=512, +) + +result = retriever.answer("What is RAG?", llm=llm) +print(result.answer) +# 'Retrieval-augmented generation combines external context with an LLM...' +print(len(result.chunks), "chunks from", {m.get("source") for m in result.metadata}) +print(f"{result.latency_s:.2f}s on {result.model}") +``` + +Local-GPU shortcut: if you ingested with default `graph_pipeline` flags +(`--embed` omitted, `[local]` extra installed), drop `embedder=` and +`embedding_endpoint=` to reuse the bundled `VL_EMBED_MODEL`. + +Live RAG with scoring and an LLM judge (requires a ground-truth `reference`): +```python +from nemo_retriever.llm import LLMJudge + +judge = LLMJudge.from_kwargs(model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1") +result = retriever.answer( + "What is RAG?", + llm=llm, + judge=judge, + reference="RAG combines retrieved context with LLM generation.", +) +print(result.token_f1, result.judge_score, result.failure_mode) +# 0.62 4 'correct' +``` + +Batch RAG over the operator graph -- each builder step is optional: +```python +df = ( + retriever.pipeline() + .generate(llm) + .score() + .judge(judge) + .run( + queries=["What is RAG?", "What is reranking?"], + reference=["RAG combines retrieval with generation.", "Reranking re-scores retrieved passages."], + ) +) +print(df[["query", "answer", "token_f1", "judge_score", "failure_mode"]]) +``` + +Scoring tiers on `AnswerResult`: + +- **Tier 1** (`answer_in_context`) -- whether retrieval surfaced the evidence; requires `reference`. +- **Tier 2** (`token_f1`, `exact_match`) -- token-level overlap; requires `reference`. +- **Tier 3** (`judge_score`, `judge_reasoning`) -- LLM-as-judge 1-5 score; requires `reference` and `judge`. +- `failure_mode` -- derived classification (`correct`, `partial`, `retrieval_miss`, `generation_miss`, `refused_*`, `thinking_truncated`). + +If only `reference` is supplied, Tier 1 + 2 run. If only `judge` is supplied (without `reference`), a `ValueError` is raised. On generation error, scoring and judge are skipped and `AnswerResult.error` is populated. 
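Tying the error semantics above together, a small defensive-handling sketch that reuses the `retriever`, `llm`, and `judge` objects constructed in the earlier examples; the reference string is illustrative, and the attribute names are the ones shown in the examples and tier list above:

```python
result = retriever.answer(
    "What is RAG?",
    llm=llm,
    judge=judge,
    reference="RAG combines retrieved context with LLM generation.",
)

if result.error:
    # Generation failed -- scoring and the judge were skipped.
    print(f"generation error on {result.model}: {result.error}")
elif result.failure_mode in {"retrieval_miss", "generation_miss"}:
    print(f"{result.failure_mode}: judge={result.judge_score}, token_f1={result.token_f1}")
    print("retrieved sources:", {m.get("source") for m in result.metadata})
else:
    print(f"{result.failure_mode} in {result.latency_s:.2f}s: {result.answer[:120]}")
```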
+ ### Ingest other types of content: For PowerPoint and Docx files, ensure libeoffice is installed by your system's package manager. This is required to make their pages renderable as images for our [page-elements content classifier](https://huggingface.co/nvidia/nemotron-page-elements-v3). diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index 8e2758554..9408502cf 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -112,10 +112,12 @@ benchmarks = [ "open-clip-torch==3.2.0", ] -eval = [ +# ── LLM client (generation + judge) ────────────────────────────────────────── +# Install this if you want to call ``Retriever.answer()``, ``Retriever.pipeline()``, +# or construct an ``LLMJudge`` / ``LiteLLMClient`` directly. Powers both the +# live-RAG SDK and the batch evaluation framework. +llm = [ "litellm>=1.40.0", - "pyyaml>=6.0", - "tenacity>=8.0.0", ] dev = [ "build>=1.2.2", diff --git a/nemo_retriever/src/nemo_retriever/evaluation/README.md b/nemo_retriever/src/nemo_retriever/evaluation/README.md index 650f7733f..d28333dd0 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/README.md +++ b/nemo_retriever/src/nemo_retriever/evaluation/README.md @@ -1,6 +1,6 @@ # QA Evaluation Pipeline -The evaluation framework lives in **`nemo_retriever.evaluation`** (install `nemo-retriever[eval]` from PyPI, or **`uv pip install -e "./nemo_retriever[eval]"` from this repo root** so `graph_pipeline` and local changes resolve). +The evaluation framework lives in **`nemo_retriever.evaluation`** (install `nemo-retriever[llm]` from PyPI, or **`uv pip install -e "./nemo_retriever[llm]"` from this repo root** so `graph_pipeline` and local changes resolve). Measures LLM answer quality over a RAG pipeline: retrieve context from a VDB, generate answers with one or more LLMs, and score each answer against ground-truth references using multi-tier scoring and an LLM-as-judge. @@ -90,9 +90,9 @@ Exact commands to reproduce the full-page markdown QA evaluation from scratch. uv venv qa-retriever --python 3.12 source qa-retriever/bin/activate -# 2. Install nemo_retriever with eval extras (from repo root) +# 2. Install nemo_retriever with LLM extras (from repo root) cd /path/to/nv-ingest -uv pip install -e "./nemo_retriever[eval]" +uv pip install -e "./nemo_retriever[llm]" # 3. Set your API key (used by generation + judging) export NVIDIA_API_KEY="nvapi-..." @@ -236,12 +236,12 @@ Steps 1-3 (ingest, build index, export) require the **`nemo_retriever`** library uv venv qa-retriever --python 3.12 source qa-retriever/bin/activate cd /path/to/nv-ingest # repo root -uv pip install -e "./nemo_retriever[eval]" +uv pip install -e "./nemo_retriever[llm]" ``` -The `[eval]` extra installs `litellm` for LLM generation and judging. If you are not using this tree, you can instead `uv pip install "nemo-retriever[eval]"` from PyPI (package name uses a hyphen). +The `[llm]` extra installs `litellm` for LLM generation and judging (and powers both the batch-eval framework and the live-RAG SDK). If you are not using this tree, you can instead `uv pip install "nemo-retriever[llm]"` from PyPI (package name uses a hyphen). -**Eval-only path:** if you already have a retrieval JSON and only need to run `retriever eval run`, an environment with `nemo_retriever[eval]` installed is sufficient. +**Eval-only path:** if you already have a retrieval JSON and only need to run `retriever eval run`, an environment with `nemo_retriever[llm]` installed is sufficient. 
### Prerequisites (data and keys) @@ -613,7 +613,7 @@ or manually via `QAEvalPipeline(retriever=..., llm_clients=..., judge=...)`. ### Protocol interfaces All three pluggable interfaces are Python `Protocol` classes defined in -`nemo_retriever.evaluation.types`. Any object that implements the right method +`nemo_retriever.llm.types`. Any object that implements the right method signature works -- no inheritance or registration required. | Protocol | Method | Default implementation | @@ -858,6 +858,54 @@ automatic in-memory build. | `EMBEDDER` | `retriever eval run --from-env` (lancedb mode) | Embedding model name | | `RETRIEVAL_SAVE_PATH` | `retriever eval run --from-env` (lancedb mode) | Optional path to persist the retrieval JSON | +### Mixing providers per component + +Every component (`Retriever`, `LiteLLMClient`, `LLMJudge`) takes its own `(model, api_base, api_key)` triple, so you can point different stages at different endpoints without any extra wiring. Unset `api_key` fields auto-resolve from `NVIDIA_API_KEY` / `NGC_API_KEY`, so the common single-provider path stays a one-liner; only reach for explicit `api_key=...` when a stage needs a distinct credential. + +```python +import os +from nemo_retriever.retriever import Retriever +from nemo_retriever.llm import LiteLLMClient, LLMJudge + +retriever = Retriever( + lancedb_uri="lancedb", + lancedb_table="nv-ingest", + embedder="nvidia/llama-nemotron-embed-1b-v2", + embedding_endpoint="https://integrate.api.nvidia.com/v1/embeddings", + top_k=5, +) + +llm = LiteLLMClient.from_kwargs( + model="openai/gpt-4o-mini", + api_base="https://api.openai.com/v1", + api_key=os.environ["OPENAI_API_KEY"], + temperature=0.2, +) + +judge = LLMJudge.from_kwargs( + model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1", +) +``` + +Transport params redact the bearer token in their `__repr__` / `__str__` (`api_key=***`), so logging a client or letting a Pydantic validation error echo a params object back will not leak credentials. + +### Batch generation failure rate + +When the fluent builder ran `.generate(...)`, the returned `DataFrame` exposes the aggregate generation failure rate on `df.attrs` so batch-eval reports can surface it alongside `token_f1` / `judge_score`: + +```python +df = retriever.pipeline().generate(llm).run(queries=questions) +print(f"generation_failure_rate = {df.attrs['generation_failure_rate']:.2%}") +``` + +`generation_failure_rate` is the fraction of rows whose `gen_error` column is non-null; it is only attached when the `.generate()` step ran. + +### Stable public surface + +`nemo_retriever.llm.__all__` is the supported integration point for the client + types layer. Import the Protocols, result dataclasses, `LiteLLMClient`, `LLMJudge`, and the `LLMInferenceParams` / `LLMRemoteClientParams` models from `nemo_retriever.llm`; deeper submodule paths (`llm.clients.litellm`, `llm.text_utils`, ...) are implementation details and may be reorganised without notice. + +The previously-provided `nemo-retriever[eval]` install extra was removed in favor of `nemo-retriever[llm]`; pin the new extra in any requirements files that still reference `[eval]`. + ## Scoring System (Three-Tier Hierarchy) Each (query, model) pair is scored by three independent tiers. 
Each tier tests a different layer of the RAG pipeline: @@ -1017,7 +1065,7 @@ The same pattern works for custom failure classifiers, alternative judge prompts ### Custom Retriever ```python -from nemo_retriever.evaluation.types import RetrieverStrategy, RetrievalResult +from nemo_retriever.llm.types import RetrieverStrategy, RetrievalResult class MyRetriever: def retrieve(self, query: str, top_k: int) -> RetrievalResult: @@ -1028,7 +1076,7 @@ class MyRetriever: ### Custom LLM Client ```python -from nemo_retriever.evaluation.types import LLMClient, GenerationResult +from nemo_retriever.llm.types import LLMClient, GenerationResult class MyClient: def generate(self, query: str, chunks: list[str]) -> GenerationResult: @@ -1039,7 +1087,7 @@ class MyClient: ### Custom Judge ```python -from nemo_retriever.evaluation.types import AnswerJudge, JudgeResult +from nemo_retriever.llm.types import AnswerJudge, JudgeResult class MyJudge: def judge(self, query: str, reference: str, candidate: str) -> JudgeResult: diff --git a/nemo_retriever/src/nemo_retriever/evaluation/__init__.py b/nemo_retriever/src/nemo_retriever/evaluation/__init__.py index 142041e92..de1dd3b75 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/__init__.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/__init__.py @@ -14,14 +14,14 @@ Types, scoring, and ``EvalOperator`` are always available. Modules that depend on ``litellm`` (generators, judges, generation, judging, orchestrator, config) are lazy-loaded so that lightweight -consumers can use scoring without installing the ``[eval]`` extra:: +consumers can use scoring without installing the ``[llm]`` extra:: - pip install nemo-retriever[eval] + pip install nemo-retriever[llm] # SDK + batch eval """ from nemo_retriever.evaluation.eval_operator import EvalOperator from nemo_retriever.evaluation.scoring import score_dataframe -from nemo_retriever.evaluation.types import ( +from nemo_retriever.llm.types import ( AnswerJudge, GenerationResult, JudgeResult, @@ -35,8 +35,8 @@ "JudgingOperator": "nemo_retriever.evaluation.judging", "ScoringOperator": "nemo_retriever.evaluation.scoring_operator", "RetrievalLoaderOperator": "nemo_retriever.evaluation.retrieval_loader", - "LiteLLMClient": "nemo_retriever.evaluation.generators", - "LLMJudge": "nemo_retriever.evaluation.judges", + "LiteLLMClient": "nemo_retriever.llm.clients", + "LLMJudge": "nemo_retriever.llm.clients", "QAEvalPipeline": "nemo_retriever.evaluation.orchestrator", "load_eval_config": "nemo_retriever.evaluation.config", "build_eval_chain": "nemo_retriever.evaluation.config", diff --git a/nemo_retriever/src/nemo_retriever/evaluation/cli.py b/nemo_retriever/src/nemo_retriever/evaluation/cli.py index 543c87823..4550f3312 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/cli.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/cli.py @@ -5,7 +5,7 @@ """``retriever eval`` Typer subcommands. All heavy imports (litellm, evaluation modules) are deferred to inside -command bodies so that ``pip install nemo-retriever`` (without ``[eval]``) +command bodies so that ``pip install nemo-retriever`` (without ``[llm]``) does not break the CLI at import time. 
""" diff --git a/nemo_retriever/src/nemo_retriever/evaluation/config.py b/nemo_retriever/src/nemo_retriever/evaluation/config.py index a6e9289ce..8fd33ea2b 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/config.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/config.py @@ -170,7 +170,7 @@ def load_eval_config(path: str) -> dict: """Load eval config from YAML (``.yaml``/``.yml``) or JSON (``.json``) file. Supports ``${VAR}`` env var expansion in string values (recursive). - YAML requires ``pyyaml`` (in ``[eval]`` extras). JSON uses stdlib. + YAML requires ``pyyaml`` (core dependency of nemo-retriever). JSON uses stdlib. Parameters ---------- @@ -199,7 +199,9 @@ def load_eval_config(path: str) -> dict: import yaml except ImportError as exc: raise ImportError( - "pyyaml is required for YAML config files. " "Install it: pip install nemo-retriever[eval]" + "pyyaml is required for YAML config files. " + "It is a core dependency of nemo-retriever; reinstall with: " + "pip install nemo-retriever" ) from exc with open(config_path, encoding="utf-8") as f: raw = yaml.safe_load(f) @@ -334,8 +336,7 @@ def build_eval_pipeline(config: dict) -> "QAEvalPipeline": A fully configured pipeline ready for ``.evaluate(qa_pairs)`` or ``.process(df)``. """ - from nemo_retriever.evaluation.generators import LiteLLMClient - from nemo_retriever.evaluation.judges import LLMJudge + from nemo_retriever.llm.clients import LLMJudge, LiteLLMClient from nemo_retriever.evaluation.orchestrator import QAEvalPipeline from nemo_retriever.evaluation.retrievers import FileRetriever @@ -363,18 +364,19 @@ def build_eval_pipeline(config: dict) -> "QAEvalPipeline": llm_clients: dict[str, LiteLLMClient] = {} for gen_cfg in generators: name = gen_cfg.get("name", gen_cfg["model"]) - llm_clients[name] = LiteLLMClient( + llm_clients[name] = LiteLLMClient.from_kwargs( model=gen_cfg["model"], api_base=gen_cfg.get("api_base"), api_key=gen_cfg.get("api_key"), temperature=gen_cfg.get("temperature", 0.0), + top_p=gen_cfg.get("top_p"), max_tokens=gen_cfg.get("max_tokens", 4096), extra_params=gen_cfg.get("extra_params"), num_retries=gen_cfg.get("num_retries", 3), timeout=gen_cfg.get("timeout", default_timeout), ) - judge = LLMJudge( + judge = LLMJudge.from_kwargs( model=judge_cfg["model"], api_base=judge_cfg.get("api_base"), api_key=judge_cfg.get("api_key"), diff --git a/nemo_retriever/src/nemo_retriever/evaluation/generation.py b/nemo_retriever/src/nemo_retriever/evaluation/generation.py index ab78f75c5..b879868de 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/generation.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/generation.py @@ -11,8 +11,8 @@ from typing import Any, ClassVar, Optional from nemo_retriever.evaluation.eval_operator import EvalOperator -from nemo_retriever.evaluation.generators import LiteLLMClient -from nemo_retriever.evaluation.types import GenerationResult +from nemo_retriever.llm.clients import LiteLLMClient +from nemo_retriever.llm.types import GenerationResult logger = logging.getLogger(__name__) @@ -53,7 +53,7 @@ def __init__( timeout=timeout, max_workers=max_workers, ) - self._client = LiteLLMClient( + self._client = LiteLLMClient.from_kwargs( model=model, api_base=api_base, api_key=api_key, diff --git a/nemo_retriever/src/nemo_retriever/evaluation/judging.py b/nemo_retriever/src/nemo_retriever/evaluation/judging.py index 0ca6546e8..d7ef29efa 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/judging.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/judging.py @@ -11,8 +11,8 
@@ from typing import Any, ClassVar, Optional from nemo_retriever.evaluation.eval_operator import EvalOperator -from nemo_retriever.evaluation.judges import LLMJudge -from nemo_retriever.evaluation.types import JudgeResult +from nemo_retriever.llm.clients import LLMJudge +from nemo_retriever.llm.types import JudgeResult logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def __init__( timeout=timeout, max_workers=max_workers, ) - self._judge = LLMJudge( + self._judge = LLMJudge.from_kwargs( model=model, api_base=api_base, api_key=api_key, diff --git a/nemo_retriever/src/nemo_retriever/evaluation/live_retrieval.py b/nemo_retriever/src/nemo_retriever/evaluation/live_retrieval.py new file mode 100644 index 000000000..92ecd41f9 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/evaluation/live_retrieval.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""LiveRetrievalOperator -- live LanceDB retrieval source for evaluation chains.""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any, ClassVar + +from nemo_retriever.evaluation.eval_operator import EvalOperator + +if TYPE_CHECKING: + from nemo_retriever.retriever import Retriever + +logger = logging.getLogger(__name__) + + +class LiveRetrievalOperator(EvalOperator): + """Live retrieval source for evaluation chains. + + Parallels :class:`~nemo_retriever.evaluation.retrieval_loader.RetrievalLoaderOperator` + but pulls chunks from LanceDB on the fly via + :meth:`Retriever.retrieve ` + rather than loading them from a pre-computed retrieval JSON. Used by + :meth:`Retriever.pipeline ` + to prepend retrieval to a DataFrame-in/out generation / scoring / + judging graph. + + Input DataFrame must have a ``query`` column. Adds ``context`` + (``list[str]`` of chunk texts per row) and ``context_metadata`` + (``list[dict]`` aligned with ``context``). + + Notes: + This operator is **inprocess-only**. The wrapped + :class:`~nemo_retriever.retriever.Retriever` instance is held as an + in-memory attribute rather than registered as a constructor kwarg, + because a live embedder / reranker cache does not serialise for + Ray fan-out. Use + :class:`~nemo_retriever.evaluation.retrieval_loader.RetrievalLoaderOperator` + instead for distributed batch evaluation. + + Example: + >>> from nemo_retriever.retriever import Retriever # doctest: +SKIP + >>> retriever = Retriever(lancedb_uri="./kb") # doctest: +SKIP + >>> op = LiveRetrievalOperator(retriever, top_k=5) # doctest: +SKIP + >>> import pandas as pd # doctest: +SKIP + >>> df = pd.DataFrame({"query": ["What is RAG?"]}) # doctest: +SKIP + >>> enriched = op.process(df) # doctest: +SKIP + >>> list(enriched.columns) # doctest: +SKIP + ['query', 'context', 'context_metadata'] + """ + + required_columns: ClassVar[tuple[str, ...]] = ("query",) + output_columns: ClassVar[tuple[str, ...]] = ("context", "context_metadata") + + def __init__(self, retriever: "Retriever", *, top_k: int = 5) -> None: + # ``retriever`` is not a serialisable constructor kwarg (embedders, + # LanceDB handles, reranker model caches), so only ``top_k`` is + # registered for get_constructor_kwargs(). This operator is + # inprocess-only as documented above. 
+ super().__init__(top_k=int(top_k)) + self._retriever = retriever + self._top_k = int(top_k) + + def process(self, data: Any, **kwargs: Any) -> Any: + import pandas as pd + + if not isinstance(data, pd.DataFrame): + raise TypeError(f"{type(self).__name__} requires a pandas.DataFrame input, " f"got {type(data).__name__}") + + out = data.copy() + query_texts = [str(q) for q in out["query"]] + + # One batched call instead of per-row iteration. The Retriever + # embeds all queries in a single NIM round trip and issues a + # single LanceDB sweep, so an N-row DataFrame pays O(1) network + # cost end-to-end rather than O(N). Order is preserved by + # ``retrieve_batch`` so ``results[i]`` aligns with row ``i``. + results = self._retriever.retrieve_batch(query_texts, top_k=self._top_k) + + if len(results) != len(query_texts): + raise RuntimeError( + "retrieve_batch returned " + f"{len(results)} results for {len(query_texts)} queries; " + "this violates the contract and points at a Retriever bug." + ) + + out["context"] = [list(r.chunks) for r in results] + out["context_metadata"] = [list(r.metadata) for r in results] + return out diff --git a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py index cabfc7e76..7df9d9eba 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py @@ -36,7 +36,7 @@ classify_failure, token_f1, ) -from nemo_retriever.evaluation.types import ( +from nemo_retriever.llm.types import ( AnswerJudge, GenerationResult, JudgeResult, diff --git a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py index c2d09a387..74884e666 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py @@ -23,7 +23,7 @@ import threading import unicodedata -from nemo_retriever.evaluation.types import RetrievalResult +from nemo_retriever.llm.types import RetrievalResult logger = logging.getLogger(__name__) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/runner.py b/nemo_retriever/src/nemo_retriever/evaluation/runner.py index a2173fa21..44f4b9a54 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/runner.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/runner.py @@ -22,7 +22,7 @@ from typing import Any, Callable, TYPE_CHECKING if TYPE_CHECKING: - from nemo_retriever.evaluation.types import RetrieverStrategy + from nemo_retriever.llm.types import RetrieverStrategy logger = logging.getLogger(__name__) @@ -64,8 +64,7 @@ def run_eval_sweep( ``"FAIL"``), ``output_path`` (or ``error``), and ``eval_results`` (the full evaluation dict when status is PASS). 
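+
+    Example (illustrative sketch of consuming the returned list; assumes
+    ``results = run_eval_sweep(...)`` has already completed)::
+
+        >>> failed = [r["generator"] for r in results if r["status"] == "FAIL"]  # doctest: +SKIP
+        >>> passed = [r for r in results if r["status"] == "PASS"]  # doctest: +SKIP
+        >>> all("eval_results" in r for r in passed)  # doctest: +SKIP
+        True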
""" - from nemo_retriever.evaluation.generators import LiteLLMClient - from nemo_retriever.evaluation.judges import LLMJudge + from nemo_retriever.llm.clients import LLMJudge, LiteLLMClient from nemo_retriever.evaluation.orchestrator import QAEvalPipeline from nemo_retriever.evaluation.retrievers import FileRetriever @@ -110,18 +109,18 @@ def run_eval_sweep( check_unresolved_env(gen_model_cfg.get("api_key"), "api_key", f"generator '{gen_name}'") check_unresolved_env(judge_model_cfg.get("api_key"), "api_key", f"judge '{judge_name}'") - client = LiteLLMClient( + client = LiteLLMClient.from_kwargs( model=gen_model_cfg["model"], api_base=gen_model_cfg.get("api_base"), api_key=gen_model_cfg.get("api_key"), temperature=eval_cfg.get("temperature", gen_model_cfg.get("temperature", 0.0)), - top_p=eval_cfg.get("top_p", gen_model_cfg.get("top_p", 1.0)), + top_p=eval_cfg.get("top_p", gen_model_cfg.get("top_p")), max_tokens=eval_cfg.get("max_tokens", gen_model_cfg.get("max_tokens", 4096)), extra_params=gen_model_cfg.get("extra_params"), num_retries=gen_model_cfg.get("num_retries", 3), timeout=gen_model_cfg.get("timeout", default_timeout), ) - judge = LLMJudge( + judge = LLMJudge.from_kwargs( model=judge_model_cfg["model"], api_base=judge_model_cfg.get("api_base"), api_key=judge_model_cfg.get("api_key"), diff --git a/nemo_retriever/src/nemo_retriever/evaluation/scoring.py b/nemo_retriever/src/nemo_retriever/evaluation/scoring.py index 121c49a04..bac508c37 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/scoring.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/scoring.py @@ -24,7 +24,7 @@ import pandas as pd -from nemo_retriever.evaluation.text_utils import strip_think_tags +from nemo_retriever.llm.text_utils import strip_think_tags _STOP_WORDS = frozenset( { diff --git a/nemo_retriever/src/nemo_retriever/evaluation/text_utils.py b/nemo_retriever/src/nemo_retriever/evaluation/text_utils.py deleted file mode 100644 index 479ff5844..000000000 --- a/nemo_retriever/src/nemo_retriever/evaluation/text_utils.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -"""Shared text processing utilities for the evaluation package.""" - -from __future__ import annotations - -import re - - -def strip_think_tags(text: str) -> str: - """Remove ... reasoning blocks from model output. - - Handles both closed tags (...) and unclosed tags where the - model hit the token limit mid-reasoning and never emitted . - Returns empty string if nothing remains after stripping so callers can - detect thinking_truncated. - """ - stripped = re.sub(r".*?", "", text, flags=re.DOTALL) - stripped = re.sub(r".*", "", stripped, flags=re.DOTALL) - return stripped.strip() diff --git a/nemo_retriever/src/nemo_retriever/evaluation/types.py b/nemo_retriever/src/nemo_retriever/evaluation/types.py deleted file mode 100644 index 31867779c..000000000 --- a/nemo_retriever/src/nemo_retriever/evaluation/types.py +++ /dev/null @@ -1,67 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -""" -Protocol definitions and dataclasses for the QA evaluation pipeline. - -These abstractions allow retrieval strategies, LLM clients, and judges -to be swapped independently without modifying the orchestrator. 
-""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from typing import Any, Optional, Protocol, runtime_checkable - - -@runtime_checkable -class RetrieverStrategy(Protocol): - """Pluggable retrieval strategy interface.""" - - def retrieve(self, query: str, top_k: int) -> "RetrievalResult": ... - - -@runtime_checkable -class LLMClient(Protocol): - """Pluggable LLM answer generation interface.""" - - def generate(self, query: str, chunks: list[str]) -> "GenerationResult": ... - - -@runtime_checkable -class AnswerJudge(Protocol): - """Pluggable answer scoring interface.""" - - def judge(self, query: str, reference: str, candidate: str) -> "JudgeResult": ... - - -@dataclass -class RetrievalResult: - """Result from a retrieval operation.""" - - chunks: list[str] - metadata: list[dict[str, Any]] = field(default_factory=list) - - -@dataclass -class GenerationResult: - """Result from a single LLM generation call.""" - - answer: str - latency_s: float - model: str - error: Optional[str] = None - - -@dataclass -class JudgeResult: - """Result from a single judge evaluation. - - ``score`` is ``None`` when the judge could not produce a score - (API error, parse failure, empty candidate). Valid scores are 1-5. - """ - - score: Optional[int] = None - reasoning: str = "" - error: Optional[str] = None diff --git a/nemo_retriever/src/nemo_retriever/llm/__init__.py b/nemo_retriever/src/nemo_retriever/llm/__init__.py new file mode 100644 index 000000000..7a23181c3 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/llm/__init__.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +LLM primitives: Protocols, result dataclasses, and concrete clients. + +Types, Protocols, and result dataclasses are always available (zero +external deps). ``LiteLLMClient`` and ``LLMJudge`` are lazy-loaded so +that lightweight consumers can use the type contracts without +installing ``litellm``:: + + from nemo_retriever.llm import RetrieverStrategy, RetrievalResult # cheap + from nemo_retriever.llm import LiteLLMClient # imports litellm on first use + +Credentials +----------- +Per-component API keys (``api_key``) and base URLs (``api_base``) are +passed directly on ``LiteLLMClient.from_kwargs`` / ``LLMJudge.from_kwargs`` +or on ``Retriever(embedding_api_key=..., embedding_endpoint=...)``. When +``api_key`` is left ``None`` the shared ``_ParamsModel`` validator +resolves ``NVIDIA_API_KEY`` / ``NGC_API_KEY`` from the environment. This +keeps the common single-provider path a one-liner while still allowing +multiple independent endpoints to coexist -- each component takes its +own ``(api_base, api_key, model)`` triple. + +Public surface contract +----------------------- +The names in ``__all__`` below are the frozen public API of this +module. External callers should import from ``nemo_retriever.llm`` +rather than reaching into submodules (``llm.clients.litellm``, +``llm.text_utils``) directly -- those submodule paths are implementation +details and may be reorganised in future releases without notice. The +Protocols + result dataclasses + concrete clients + re-exported params +models listed here are the supported integration points. 
+""" + +from nemo_retriever.llm.types import ( + AnswerJudge, + AnswerResult, + GenerationResult, + JudgeResult, + LLMClient, + RetrievalResult, + RetrieverStrategy, +) +from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams + +_LAZY_IMPORTS = { + "LiteLLMClient": "nemo_retriever.llm.clients.litellm", + "LLMJudge": "nemo_retriever.llm.clients.judge", +} + + +def __getattr__(name: str): + if name in _LAZY_IMPORTS: + import importlib + + module = importlib.import_module(_LAZY_IMPORTS[name]) + return getattr(module, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + # Protocols + "AnswerJudge", + "LLMClient", + "RetrieverStrategy", + # Result dataclasses + "AnswerResult", + "GenerationResult", + "JudgeResult", + "RetrievalResult", + # Concrete clients (lazy-loaded) + "LLMJudge", + "LiteLLMClient", + # Transport / sampling params (re-exported for ergonomics) + "LLMInferenceParams", + "LLMRemoteClientParams", +] diff --git a/nemo_retriever/src/nemo_retriever/llm/clients/__init__.py b/nemo_retriever/src/nemo_retriever/llm/clients/__init__.py new file mode 100644 index 000000000..9c19065e8 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/llm/clients/__init__.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Concrete LLM client implementations and a lightweight client registry. + +The ``llm.clients`` package hosts every concrete :class:`LLMClient` +implementation in its own submodule (``litellm.py``, ``judge.py``, ...) +so that adding a new client means adding a new file rather than +extending a monolithic module. To keep the public import path stable, +the registered client classes plus the internal prompt-helper +(``_build_rag_prompt``) and response-parser (``_parse_judge_response``) +are re-exported from this package's namespace. Any caller that imports +``from nemo_retriever.llm.clients import LiteLLMClient`` will therefore +continue to work unchanged after the module-to-package refactor. +""" + +from __future__ import annotations + +from nemo_retriever.llm.clients.judge import LLMJudge, _parse_judge_response +from nemo_retriever.llm.clients.litellm import LiteLLMClient, _build_rag_prompt + +_REGISTRY: dict[str, type] = {} + + +def register_client(name: str, cls: type) -> None: + """Register a client class under a human-readable name. + + The registry is optional: every concrete client remains addressable + by its import path. Registration is offered as a convenience for + configuration-driven instantiation (e.g. reading a ``type`` key from + a YAML file and looking up the matching class). + + Args: + name: Stable lookup name (e.g. ``"litellm"``). + cls: The client class. Must expose a ``from_kwargs`` + classmethod to be useful to configuration-driven callers. + """ + + _REGISTRY[name] = cls + + +def get_client(name: str) -> type: + """Return a registered client class by name. + + Raises: + KeyError: When ``name`` has not been registered. 
+ """ + + return _REGISTRY[name] + + +register_client("litellm", LiteLLMClient) + + +__all__ = [ + "LLMJudge", + "LiteLLMClient", + "_build_rag_prompt", + "_parse_judge_response", + "get_client", + "register_client", +] diff --git a/nemo_retriever/src/nemo_retriever/evaluation/judges.py b/nemo_retriever/src/nemo_retriever/llm/clients/judge.py similarity index 71% rename from nemo_retriever/src/nemo_retriever/evaluation/judges.py rename to nemo_retriever/src/nemo_retriever/llm/clients/judge.py index 8cb76e6ba..813328203 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/judges.py +++ b/nemo_retriever/src/nemo_retriever/llm/clients/judge.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -LLM-as-judge scoring for the QA evaluation pipeline. +LLM-as-judge scoring. LLMJudge uses a strong LLM to score generated answers on a 1-5 scale against a ground-truth reference answer. @@ -15,8 +15,9 @@ import re from typing import Any, Optional -from nemo_retriever.evaluation.generators import LiteLLMClient -from nemo_retriever.evaluation.types import JudgeResult +from nemo_retriever.llm.clients.litellm import LiteLLMClient +from nemo_retriever.llm.types import JudgeResult +from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams _JUDGE_SYSTEM_PROMPT = """\ You are an expert evaluator for factual question answering. @@ -59,27 +60,62 @@ class LLMJudge: - """LLM-as-judge that scores candidate answers on a 1-5 scale.""" + """LLM-as-judge that scores candidate answers on a 1-5 scale. + + Configuration is split into two Pydantic objects: + + * ``transport``: :class:`~nemo_retriever.params.LLMRemoteClientParams` + owns the endpoint, api_key, retries, and timeout for the judge model. + * ``sampling``: :class:`~nemo_retriever.params.LLMInferenceParams` + owns ``temperature`` / ``top_p`` / ``max_tokens``. Defaults to + ``temperature=0.0, max_tokens=256`` for deterministic scoring. + + Use :meth:`from_kwargs` for a flat, backwards-compatible constructor. + """ + + _DEFAULT_MODEL: str = "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1" + _DEFAULT_SAMPLING: LLMInferenceParams = LLMInferenceParams(temperature=0.0, max_tokens=256) def __init__( self, - model: str = "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1", + transport: LLMRemoteClientParams, + sampling: Optional[LLMInferenceParams] = None, + ): + self._client = LiteLLMClient( + transport=transport, + sampling=sampling if sampling is not None else self._DEFAULT_SAMPLING, + ) + + @property + def model(self) -> str: + """Return the judge model identifier from the transport params.""" + return self._client.transport.model + + @classmethod + def from_kwargs( + cls, + *, + model: str = _DEFAULT_MODEL, api_base: Optional[str] = None, api_key: Optional[str] = None, extra_params: Optional[dict[str, Any]] = None, + num_retries: int = 3, timeout: float = 120.0, - ): - self._client = LiteLLMClient( + ) -> "LLMJudge": + """Flat-kwarg constructor for zero-churn migration from the old signature. + + Sampling is left at the class default (deterministic 0.0 temperature, + 256 max tokens). Use the two-arg constructor to override sampling. 
+ """ + transport = LLMRemoteClientParams( model=model, api_base=api_base, api_key=api_key, - temperature=0.0, - max_tokens=256, - extra_params=extra_params or {}, - num_retries=3, + num_retries=num_retries, timeout=timeout, + extra_params=extra_params or {}, ) - self.model = model + return cls(transport=transport) def judge(self, query: str, reference: str, candidate: str) -> JudgeResult: """Score a candidate answer against the reference answer.""" diff --git a/nemo_retriever/src/nemo_retriever/evaluation/generators.py b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py similarity index 59% rename from nemo_retriever/src/nemo_retriever/evaluation/generators.py rename to nemo_retriever/src/nemo_retriever/llm/clients/litellm.py index 3e6ea2e9c..360c0f2bc 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/generators.py +++ b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 """ -LLM answer generation client for the QA evaluation pipeline. +Unified LLM answer generation client. LiteLLMClient wraps the litellm library which provides a single interface for routing to NVIDIA NIM, OpenAI, HuggingFace Inference Endpoints, and @@ -16,8 +16,9 @@ import time from typing import Any, Optional -from nemo_retriever.evaluation.text_utils import strip_think_tags -from nemo_retriever.evaluation.types import GenerationResult +from nemo_retriever.llm.text_utils import strip_think_tags +from nemo_retriever.llm.types import GenerationResult +from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams logger = logging.getLogger(__name__) @@ -58,49 +59,87 @@ class LiteLLMClient: Provider API keys are read from environment variables automatically (NVIDIA_API_KEY, OPENAI_API_KEY, HUGGINGFACE_API_KEY, etc.). + + Configuration is split into two orthogonal Pydantic objects: + + * ``transport``: :class:`~nemo_retriever.params.LLMRemoteClientParams` + owns provider endpoint, authentication, retry, and timeout. + * ``sampling``: :class:`~nemo_retriever.params.LLMInferenceParams` + owns ``temperature``, ``top_p``, and ``max_tokens``. + + Use :meth:`from_kwargs` for a flat, backwards-compatible constructor. """ + _DEFAULT_MODEL: str = "nvidia_nim/nvidia/llama-3.3-nemotron-super-49b-v1.5" + def __init__( self, - model: str, + transport: LLMRemoteClientParams, + sampling: Optional[LLMInferenceParams] = None, + ): + self.transport = transport + self.sampling = sampling if sampling is not None else LLMInferenceParams() + + @property + def model(self) -> str: + """Return the model identifier from the transport params.""" + return self.transport.model + + @classmethod + def from_kwargs( + cls, + *, + model: str = _DEFAULT_MODEL, api_base: Optional[str] = None, api_key: Optional[str] = None, temperature: float = 0.0, - top_p: float = 1.0, + top_p: Optional[float] = None, max_tokens: int = 4096, extra_params: Optional[dict[str, Any]] = None, num_retries: int = 3, timeout: float = 120.0, - ): - self.model = model - self.api_base = api_base - self.api_key = api_key - self.temperature = temperature - self.top_p = top_p - self.max_tokens = max_tokens - self.extra_params = extra_params or {} - self.num_retries = num_retries - self.timeout = timeout + ) -> "LiteLLMClient": + """Flat-kwarg constructor for zero-churn migration from the old signature. + + Splits the flat kwargs into the two structured params objects. All + validation (temperature range, ``num_retries >= 0``, ``timeout > 0``) + is delegated to the Pydantic models. 
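+
+        Example (illustrative)::
+
+            >>> client = LiteLLMClient.from_kwargs(
+            ...     model="nvidia_nim/meta/llama-3.3-70b-instruct",
+            ...     temperature=0.2,
+            ... )  # doctest: +SKIP
+            >>> (client.transport.num_retries, client.sampling.temperature)  # doctest: +SKIP
+            (3, 0.2)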
+ """ + transport = LLMRemoteClientParams( + model=model, + api_base=api_base, + api_key=api_key, + num_retries=num_retries, + timeout=timeout, + extra_params=extra_params or {}, + ) + sampling = LLMInferenceParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + ) + return cls(transport=transport, sampling=sampling) def complete(self, messages: list[dict], max_tokens: Optional[int] = None) -> tuple[str, float]: """Raw litellm completion call. Returns (content_text, latency_s).""" import litellm + sampling_kwargs = self.sampling.to_sampling_kwargs() + if max_tokens is not None: + sampling_kwargs["max_tokens"] = max_tokens + call_kwargs: dict[str, Any] = { - "model": self.model, + "model": self.transport.model, "messages": messages, - "temperature": self.temperature, - "max_tokens": max_tokens if max_tokens is not None else self.max_tokens, - "num_retries": self.num_retries, - "timeout": self.timeout, + "num_retries": self.transport.num_retries, + "timeout": self.transport.timeout, + **sampling_kwargs, } - if self.top_p is not None and self.top_p != 1.0: - call_kwargs["top_p"] = self.top_p - if self.api_base: - call_kwargs["api_base"] = self.api_base - if self.api_key: - call_kwargs["api_key"] = self.api_key - call_kwargs.update(self.extra_params) + if self.transport.api_base: + call_kwargs["api_base"] = self.transport.api_base + if self.transport.api_key: + call_kwargs["api_key"] = self.transport.api_key + call_kwargs.update(self.transport.extra_params) t0 = time.monotonic() try: @@ -114,7 +153,7 @@ def complete(self, messages: list[dict], max_tokens: Optional[int] = None) -> tu "only accept one. Either remove `top_p` from the model " "config or set `temperature` to null. Sent: " "temperature=%s, top_p=%s", - self.model, + self.transport.model, call_kwargs.get("temperature"), call_kwargs.get("top_p"), ) @@ -133,15 +172,15 @@ def generate(self, query: str, chunks: list[str]) -> GenerationResult: return GenerationResult( answer="", latency_s=latency, - model=self.model, + model=self.transport.model, error="thinking_truncated", ) - return GenerationResult(answer=answer, latency_s=latency, model=self.model) + return GenerationResult(answer=answer, latency_s=latency, model=self.transport.model) except Exception as exc: - logger.debug("Generation failed for model=%s: %s", self.model, exc) + logger.debug("Generation failed for model=%s: %s", self.transport.model, exc) return GenerationResult( answer="", latency_s=0.0, - model=self.model, + model=self.transport.model, error=str(exc), ) diff --git a/nemo_retriever/src/nemo_retriever/llm/text_utils.py b/nemo_retriever/src/nemo_retriever/llm/text_utils.py new file mode 100644 index 000000000..34b8634a2 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/llm/text_utils.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Shared text-processing utilities for LLM output hygiene. + +Pure-stdlib module. Lives under ``nemo_retriever.llm`` so that the +lightweight SDK surface (``LiteLLMClient``, ``Retriever.answer``) does +not pull in ``pandas`` or any evaluation dependencies just to clean +```` tags out of a model response. +""" + +from __future__ import annotations + +import re + + +def strip_think_tags(text: str) -> str: + """Remove ``...`` reasoning blocks from model output. + + Handles both closed tags (``...``) and unclosed tags + where the model hit the token limit mid-reasoning and never emitted + ````. 
Returns an empty string if nothing remains after
+    stripping so callers can detect ``thinking_truncated``.
+    """
+    stripped = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
+    stripped = re.sub(r"<think>.*", "", stripped, flags=re.DOTALL)
+    return stripped.strip()
diff --git a/nemo_retriever/src/nemo_retriever/llm/types.py b/nemo_retriever/src/nemo_retriever/llm/types.py
new file mode 100644
index 000000000..a0030aba9
--- /dev/null
+++ b/nemo_retriever/src/nemo_retriever/llm/types.py
@@ -0,0 +1,117 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+"""
+Protocol definitions and result dataclasses for LLM-based pipelines.
+
+These abstractions allow retrieval strategies, LLM clients, and judges
+to be swapped independently. They are consumed by both the evaluation
+framework (``nemo_retriever.evaluation``) and the live RAG surface on
+``nemo_retriever.retriever.Retriever``.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Optional, Protocol, runtime_checkable
+
+
+@runtime_checkable
+class RetrieverStrategy(Protocol):
+    """Pluggable retrieval strategy interface."""
+
+    def retrieve(self, query: str, top_k: int) -> "RetrievalResult": ...
+
+
+@runtime_checkable
+class LLMClient(Protocol):
+    """Pluggable LLM answer generation interface."""
+
+    def generate(self, query: str, chunks: list[str]) -> "GenerationResult": ...
+
+
+@runtime_checkable
+class AnswerJudge(Protocol):
+    """Pluggable answer scoring interface."""
+
+    def judge(self, query: str, reference: str, candidate: str) -> "JudgeResult": ...
+
+
+@dataclass
+class RetrievalResult:
+    """Result from a retrieval operation."""
+
+    chunks: list[str]
+    metadata: list[dict[str, Any]] = field(default_factory=list)
+
+
+@dataclass
+class GenerationResult:
+    """Result from a single LLM generation call."""
+
+    answer: str
+    latency_s: float
+    model: str
+    error: Optional[str] = None
+
+
+@dataclass
+class JudgeResult:
+    """Result from a single judge evaluation.
+
+    ``score`` is ``None`` when the judge could not produce a score
+    (API error, parse failure, empty candidate). Valid scores are 1-5.
+    """
+
+    score: Optional[int] = None
+    reasoning: str = ""
+    error: Optional[str] = None
+
+
+@dataclass
+class AnswerResult:
+    """Result from a single live-RAG call to ``Retriever.answer``.
+
+    Holds the generated answer alongside the retrieved context that was used
+    to produce it and -- when a ``reference`` answer and/or ``judge`` are
+    supplied -- the Tier-1 / Tier-2 / Tier-3 scoring artefacts produced by
+    :mod:`nemo_retriever.evaluation.scoring` and
+    :class:`~nemo_retriever.llm.clients.judge.LLMJudge`.
+
+    Attributes:
+        query: The question that was answered.
+        answer: The generated answer text.
+        chunks: Retrieved chunk texts used as context, in rank order.
+        metadata: Per-chunk metadata (source, page_number, etc.), aligned
+            with ``chunks``.
+        model: Model identifier that produced ``answer``.
+        latency_s: Wall-clock latency of the generation call in seconds.
+        error: Non-None when generation failed. Scoring and judge are
+            skipped when ``error`` is set.
+        judge_score: LLM-judge Tier-3 score (1-5) when a judge was run.
+        judge_reasoning: One-sentence rationale emitted by the judge.
+        judge_error: Non-None when the judge call failed.
+        token_f1: Tier-2 token-level F1 between ``answer`` and the
+            reference answer (0.0-1.0).
+        exact_match: Tier-2 normalised exact-match flag.
+ answer_in_context: Tier-1 flag -- True if at least half of the + reference answer's content words appear in the retrieved chunks. + failure_mode: Classification produced by + :func:`~nemo_retriever.evaluation.scoring.classify_failure`. + """ + + query: str + answer: str + chunks: list[str] + metadata: list[dict[str, Any]] + model: str + latency_s: float + error: Optional[str] = None + judge_score: Optional[int] = None + judge_reasoning: Optional[str] = None + judge_error: Optional[str] = None + token_f1: Optional[float] = None + exact_match: Optional[bool] = None + answer_in_context: Optional[bool] = None + failure_mode: Optional[str] = None diff --git a/nemo_retriever/src/nemo_retriever/params/__init__.py b/nemo_retriever/src/nemo_retriever/params/__init__.py index a70698d9b..81c5636ad 100644 --- a/nemo_retriever/src/nemo_retriever/params/__init__.py +++ b/nemo_retriever/src/nemo_retriever/params/__init__.py @@ -18,6 +18,7 @@ from .models import IngestorCreateParams from .models import LanceDbParams from .models import LLMInferenceParams +from .models import LLMRemoteClientParams from .models import ModelRuntimeParams from .models import OcrParams from .models import PageElementsParams @@ -48,6 +49,7 @@ "IngestorCreateParams", "LanceDbParams", "LLMInferenceParams", + "LLMRemoteClientParams", "ModelRuntimeParams", "OcrParams", "PageElementsParams", diff --git a/nemo_retriever/src/nemo_retriever/params/models.py b/nemo_retriever/src/nemo_retriever/params/models.py index c899cbc48..4a49887e3 100644 --- a/nemo_retriever/src/nemo_retriever/params/models.py +++ b/nemo_retriever/src/nemo_retriever/params/models.py @@ -24,13 +24,36 @@ NO_API_KEY = "" +_REDACTED = "***" + + +def _is_api_key_field(field_name: str) -> bool: + """Return True when ``field_name`` should be masked in ``repr`` / logs.""" + return field_name == "api_key" or field_name.endswith("_api_key") + + class _ParamsModel(BaseModel): + """Shared base for all remote-transport Pydantic params models. + + Two cross-cutting behaviours live here: + + * :meth:`_resolve_api_keys` auto-fills unset ``*api_key`` fields from + ``NVIDIA_API_KEY`` / ``NGC_API_KEY`` (see + :func:`nemo_retriever.utils.remote_auth.resolve_remote_api_key`). + * :meth:`__repr__` redacts every field whose name matches + :func:`_is_api_key_field` so that logging a transport object (or + letting Pydantic's default error formatter echo one back) never + prints a bearer token. The underlying field still serialises as + a plain ``str`` via ``.model_dump()`` / ``getattr(self, field)`` + so no downstream consumer needs changes. 
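+
+    Example (illustrative; ``LLMRemoteClientParams`` below is one such
+    subclass)::
+
+        >>> params = LLMRemoteClientParams(model="m", api_key="nvapi-secret")
+        >>> repr(params)  # doctest: +SKIP
+        "LLMRemoteClientParams(model='m', api_base=None, api_key=***, num_retries=3, timeout=120.0, extra_params={})"
+        >>> params.model_dump()["api_key"]  # doctest: +SKIP
+        'nvapi-secret'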
+ """ + model_config = ConfigDict(extra="forbid") @model_validator(mode="after") def _resolve_api_keys(self) -> "_ParamsModel": for field_name in type(self).model_fields: - if field_name == "api_key" or field_name.endswith("_api_key"): + if _is_api_key_field(field_name): value = getattr(self, field_name, None) if value is None: setattr(self, field_name, resolve_remote_api_key()) @@ -38,6 +61,18 @@ def _resolve_api_keys(self) -> "_ParamsModel": setattr(self, field_name, None) return self + def __repr__(self) -> str: + parts: list[str] = [] + for field_name in type(self).model_fields: + value = getattr(self, field_name, None) + if _is_api_key_field(field_name) and value: + parts.append(f"{field_name}={_REDACTED}") + else: + parts.append(f"{field_name}={value!r}") + return f"{type(self).__name__}({', '.join(parts)})" + + __str__ = __repr__ + class RemoteRetryParams(_ParamsModel): remote_max_pool_workers: int = 8 @@ -392,6 +427,36 @@ def to_sampling_kwargs(self) -> dict[str, Any]: return kw +class LLMRemoteClientParams(_ParamsModel): + """Transport / connection parameters for any remote LLM client. + + Pairs with :class:`LLMInferenceParams` (sampling) to fully specify a + call. ``api_key`` is auto-resolved from the environment by + :class:`_ParamsModel` when left as ``None``. + """ + + model: str + api_base: Optional[str] = None + api_key: Optional[str] = None + num_retries: int = 3 + timeout: float = 120.0 + extra_params: dict[str, Any] = Field(default_factory=dict) + + @field_validator("num_retries") + @classmethod + def _check_retries(cls, v: int) -> int: + if v < 0: + raise ValueError("num_retries must be >= 0") + return v + + @field_validator("timeout") + @classmethod + def _check_timeout(cls, v: float) -> float: + if v <= 0: + raise ValueError("timeout must be > 0") + return v + + class CaptionParams(LLMInferenceParams): endpoint_url: Optional[str] = None model_name: str = "nvidia/NVIDIA-Nemotron-Nano-12B-v2-VL-BF16" diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 59d3e585a..4390e7d37 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -6,11 +6,21 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Optional, Sequence +from typing import TYPE_CHECKING, Any, Optional, Sequence from tqdm import tqdm from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL +if TYPE_CHECKING: + import pandas as pd + + from nemo_retriever.llm.types import ( + AnswerJudge, + AnswerResult, + LLMClient, + RetrievalResult, + ) + _KEEP_KEYS = frozenset( { "text", @@ -120,7 +130,7 @@ def _embed_queries_nim( query_texts, model_name=model, embedding_endpoint=endpoint, - nvidia_api_key=(self.embedding_api_key or "").strip(), + nvidia_api_key=self.embedding_api_key, input_type="query", ) out: list[list[float]] = [] @@ -349,6 +359,333 @@ def queries( return results + # ------------------------------------------------------------------ + # Live RAG API (structured retrieval + generation) + # ------------------------------------------------------------------ + + def retrieve( + self, + query: str, + top_k: Optional[int] = None, + *, + embedder: Optional[str] = None, + lancedb_uri: Optional[str] = None, + lancedb_table: Optional[str] = None, + ) -> "RetrievalResult": + """Run retrieval for a single query and return a structured result. 
+ + Thin adapter over :meth:`query` that reshapes the raw LanceDB hits + into a :class:`~nemo_retriever.llm.RetrievalResult` with ``chunks`` + (the retrieved text, in rank order) and aligned ``metadata`` + (source, page_number, etc.). Satisfies the + :class:`~nemo_retriever.llm.RetrieverStrategy` Protocol. + + Args: + query: The natural-language query. + top_k: Override ``self.top_k`` for this call. When ``None`` the + instance attribute is used. + embedder: Override ``self.embedder`` for this call. + lancedb_uri: Override ``self.lancedb_uri`` for this call. + lancedb_table: Override ``self.lancedb_table`` for this call. + + Returns: + A :class:`~nemo_retriever.llm.RetrievalResult` whose ``chunks`` + and ``metadata`` lists have the same length. + + Example: + >>> retriever = Retriever(lancedb_uri="./kb") + >>> result = retriever.retrieve("What is RAG?", top_k=3) + >>> result.chunks[0][:40] # doctest: +SKIP + 'Retrieval augmented generation combines...' + """ + from nemo_retriever.llm.types import RetrievalResult + + previous_top_k = self.top_k + if top_k is not None: + self.top_k = int(top_k) + try: + hits = self.query( + query, + embedder=embedder, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ) + finally: + self.top_k = previous_top_k + + chunks: list[str] = [] + metadata: list[dict[str, Any]] = [] + for hit in hits: + chunks.append(str(hit.get("text", ""))) + metadata.append({k: v for k, v in hit.items() if k != "text"}) + return RetrievalResult(chunks=chunks, metadata=metadata) + + def retrieve_batch( + self, + queries: Sequence[str], + *, + top_k: Optional[int] = None, + embedder: Optional[str] = None, + lancedb_uri: Optional[str] = None, + lancedb_table: Optional[str] = None, + ) -> list["RetrievalResult"]: + """Run retrieval for a batch of queries in a single embedder call. + + This is the batched analogue of :meth:`retrieve`. It funnels the + whole query list through :meth:`queries`, which already dispatches + exactly one call to ``_embed_queries_nim`` (or the local HF + embedder) regardless of ``len(queries)``. Callers that previously + looped over :meth:`retrieve` per row pay ``N`` sequential round + trips to the embed service; routing through ``retrieve_batch`` + collapses that to a single request and a single LanceDB search + sweep. + + Args: + queries: Iterable of natural-language query strings. Order + is preserved in the returned list. + top_k: Per-call override of ``self.top_k`` (scoped via + ``try/finally`` so the instance attribute is restored on + return, mirroring :meth:`retrieve`). + embedder: Per-call embedder override. + lancedb_uri: Per-call LanceDB URI override. + lancedb_table: Per-call LanceDB table override. + + Returns: + A list of :class:`~nemo_retriever.llm.RetrievalResult`, + aligned one-to-one with ``queries``. Empty input returns an + empty list. 
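+
+        Example (illustrative sketch; assumes an already-indexed LanceDB
+        table is reachable from this instance)::
+
+            >>> results = retriever.retrieve_batch(  # doctest: +SKIP
+            ...     ["What is RAG?", "What is a NIM?"], top_k=3
+            ... )
+            >>> [len(r.chunks) for r in results]  # doctest: +SKIP
+            [3, 3]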
+ """ + + from nemo_retriever.llm.types import RetrievalResult + + query_texts = [str(q) for q in queries] + if not query_texts: + return [] + + previous_top_k = self.top_k + if top_k is not None: + self.top_k = int(top_k) + try: + hits_per_query = self.queries( + query_texts, + embedder=embedder, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ) + finally: + self.top_k = previous_top_k + + results: list[RetrievalResult] = [] + for hits in hits_per_query: + chunks = [str(hit.get("text", "")) for hit in hits] + metadata = [{k: v for k, v in hit.items() if k != "text"} for hit in hits] + results.append(RetrievalResult(chunks=chunks, metadata=metadata)) + return results + + def answer( + self, + query: str, + *, + llm: "LLMClient", + judge: Optional["AnswerJudge"] = None, + reference: Optional[str] = None, + top_k: Optional[int] = None, + embedder: Optional[str] = None, + lancedb_uri: Optional[str] = None, + lancedb_table: Optional[str] = None, + ) -> "AnswerResult": + """Run live RAG for a single query and optionally score the answer. + + Performs ``retrieve -> llm.generate`` and, when a ``reference`` answer + (for token-level scoring) and/or a ``judge`` (for LLM-as-judge + scoring) are supplied, fans those out concurrently on a small thread + pool so the judge network call and the local token-F1 computation do + not serialize. + + Scoring tiers that can be populated on the returned + :class:`~nemo_retriever.llm.AnswerResult`: + + * Tier 1 -- ``answer_in_context`` (requires ``reference``) + * Tier 2 -- ``token_f1``, ``exact_match`` (requires ``reference``) + * Tier 3 -- ``judge_score``, ``judge_reasoning`` (requires ``judge`` + and ``reference``); also populates ``failure_mode`` + + When generation fails the returned result has ``error`` populated + and all scoring/judge fields remain ``None`` -- scoring is skipped + to avoid misleading metrics on an empty answer. + + Args: + query: Natural-language question. + llm: Any object satisfying the + :class:`~nemo_retriever.llm.LLMClient` Protocol (typically + :class:`~nemo_retriever.llm.LiteLLMClient`). + judge: Optional LLM-as-judge. Requires ``reference``. + reference: Ground-truth answer for token-F1 and judge scoring. + top_k: Per-call override of ``self.top_k``. + embedder: Per-call override of ``self.embedder``. + lancedb_uri: Per-call override of ``self.lancedb_uri``. + lancedb_table: Per-call override of ``self.lancedb_table``. + + Returns: + An :class:`~nemo_retriever.llm.AnswerResult` carrying the + generated answer, the retrieved context, and any scoring + artefacts that were requested. + + Raises: + ValueError: If ``judge`` is supplied without ``reference``. + + Example: + >>> from nemo_retriever.llm import LiteLLMClient + >>> retriever = Retriever(lancedb_uri="./kb") + >>> llm = LiteLLMClient.from_kwargs( + ... model="nvidia_nim/meta/llama-3.3-70b-instruct", + ... ) + >>> result = retriever.answer( # doctest: +SKIP + ... "What did Q4 revenue look like?", + ... llm=llm, + ... ) + >>> result.answer # doctest: +SKIP + 'Revenue grew 12% YoY to $4.2B...' 
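+
+            A scored variant (illustrative; assumes a ground-truth
+            reference answer is available):
+
+            >>> from nemo_retriever.llm import LLMJudge  # doctest: +SKIP
+            >>> scored = retriever.answer(  # doctest: +SKIP
+            ...     "What did Q4 revenue look like?",
+            ...     llm=llm,
+            ...     judge=LLMJudge.from_kwargs(),
+            ...     reference="Revenue grew 12% YoY to $4.2B.",
+            ... )
+            >>> scored.judge_score is not None  # doctest: +SKIP
+            True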
+ """ + from nemo_retriever.llm.types import AnswerResult + + if judge is not None and reference is None: + raise ValueError("judge requires reference") + + retrieved = self.retrieve( + query, + top_k=top_k, + embedder=embedder, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ) + + gen = llm.generate(query, retrieved.chunks) + + result = AnswerResult( + query=query, + answer=gen.answer, + chunks=retrieved.chunks, + metadata=retrieved.metadata, + model=gen.model, + latency_s=gen.latency_s, + error=gen.error, + ) + + if gen.error is not None: + return result + + if reference is None and judge is None: + return result + + self._populate_scores( + result, + query=query, + reference=reference, + judge=judge, + gen_error=gen.error, + ) + return result + + def _populate_scores( + self, + result: "AnswerResult", + *, + query: str, + reference: Optional[str], + judge: Optional["AnswerJudge"], + gen_error: Optional[str], + ) -> None: + """Populate scoring tiers on ``result`` in-place. + + Runs Tier-1 + Tier-2 (pure CPU, sub-millisecond) alongside the Tier-3 + judge API call (network-bound) on a two-worker thread pool so the + judge latency is not extended by scoring. After both complete, + ``failure_mode`` is derived from the combined signals via + :func:`~nemo_retriever.evaluation.scoring.classify_failure`. + """ + from concurrent.futures import ThreadPoolExecutor + + from nemo_retriever.evaluation.scoring import ( + answer_in_context, + classify_failure, + token_f1, + ) + + def _scoring() -> tuple[Optional[bool], Optional[float], Optional[bool]]: + if reference is None: + return None, None, None + aic = answer_in_context(reference, result.chunks) + f1 = token_f1(reference, result.answer) + return aic, float(f1.get("f1", 0.0)), bool(f1.get("exact_match", False)) + + def _judging() -> tuple[Optional[int], Optional[str], Optional[str]]: + if judge is None or reference is None: + return None, None, None + jr = judge.judge(query, reference, result.answer) + return jr.score, jr.reasoning, jr.error + + with ThreadPoolExecutor(max_workers=2) as pool: + scoring_future = pool.submit(_scoring) + judge_future = pool.submit(_judging) + aic, f1, em = scoring_future.result() + judge_score, judge_reasoning, judge_error = judge_future.result() + + result.answer_in_context = aic + result.token_f1 = f1 + result.exact_match = em + result.judge_score = judge_score + result.judge_reasoning = judge_reasoning + result.judge_error = judge_error + + if reference is not None and aic is not None: + result.failure_mode = classify_failure( + ref_in_chunks=aic, + judge_score=judge_score, + gen_error=gen_error, + candidate=result.answer, + ) + + def pipeline(self, *, top_k: Optional[int] = None) -> "RetrieverPipelineBuilder": + """Return a fluent builder for a batch live-RAG operator graph. + + The builder composes existing evaluation operators -- live retrieval + (via :class:`~nemo_retriever.evaluation.live_retrieval.LiveRetrievalOperator`), + :class:`~nemo_retriever.evaluation.generation.QAGenerationOperator`, + :class:`~nemo_retriever.evaluation.scoring_operator.ScoringOperator`, + and :class:`~nemo_retriever.evaluation.judging.JudgingOperator` -- + using the existing ``>>`` chaining from + :mod:`nemo_retriever.graph.pipeline_graph`. No new graph primitives + are introduced; this method is sugar for building and executing that + graph against a list of queries. + + Steps are optional and independent. Call only the ones you want, in + any order (retrieval always runs first since it is the source). 
+ + Args: + top_k: Override ``self.top_k`` for retrieval within this + pipeline. Defaults to the instance attribute. + + Returns: + A :class:`RetrieverPipelineBuilder` whose ``.run(queries)`` method + executes the composed graph and returns a ``pandas.DataFrame``. + + Example: + >>> from nemo_retriever.llm import LiteLLMClient, LLMJudge + >>> retriever = Retriever(lancedb_uri="./kb") + >>> llm = LiteLLMClient.from_kwargs(model="nvidia_nim/meta/llama-3.3-70b-instruct") + >>> judge = LLMJudge.from_kwargs(model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1") + >>> df = ( # doctest: +SKIP + ... retriever.pipeline() + ... .generate(llm) + ... .score() + ... .judge(judge) + ... .run(queries=["What is RAG?"], reference=["Retrieval-augmented generation..."]) + ... ) + """ + effective_top_k = int(top_k) if top_k is not None else int(self.top_k) + return RetrieverPipelineBuilder(self, top_k=effective_top_k) + def generate_sql(self, query: str) -> str: """Generate a SQL query for a given natural language query.""" from nemo_retriever.tabular_data.retrieval import generate_sql @@ -356,5 +693,196 @@ def generate_sql(self, query: str) -> str: return generate_sql(query) +class RetrieverPipelineBuilder: + """Fluent builder for live-RAG batch operator graphs. + + Returned from :meth:`Retriever.pipeline`. Each builder method appends + an :class:`~nemo_retriever.evaluation.eval_operator.EvalOperator` to an + internal list; :meth:`run` composes them into a graph via the existing + ``>>`` chaining and executes it on a DataFrame built from the provided + queries. + + Example: + >>> builder = retriever.pipeline() # doctest: +SKIP + >>> df = builder.generate(llm).score().judge(judge).run( # doctest: +SKIP + ... queries=["q1", "q2"], + ... reference=["r1", "r2"], + ... ) + """ + + def __init__(self, retriever: "Retriever", *, top_k: int = 5) -> None: + self._retriever = retriever + self._top_k = int(top_k) + self._steps: list[Any] = [] + + def with_retrieval(self, *, top_k: int) -> "RetrieverPipelineBuilder": + """Override the ``top_k`` used for the live retrieval source.""" + self._top_k = int(top_k) + return self + + def generate( + self, + llm: Optional[Any] = None, + /, + *, + model: Optional[str] = None, + **kwargs: Any, + ) -> "RetrieverPipelineBuilder": + """Append a :class:`QAGenerationOperator` step. + + Accepts either a pre-built + :class:`~nemo_retriever.llm.clients.LiteLLMClient` (whose transport + and sampling params are unpacked onto the operator) or the flat + ``model=..., api_base=..., ...`` kwargs forwarded to the operator + constructor directly. + + Raises: + ValueError: If neither ``llm`` nor ``model`` is provided. 
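+
+        Example (illustrative; either a pre-built client or flat kwargs
+        may be supplied)::
+
+            >>> builder.generate(llm)  # doctest: +SKIP
+            >>> builder.generate(  # doctest: +SKIP
+            ...     model="nvidia_nim/meta/llama-3.3-70b-instruct",
+            ...     temperature=0.0,
+            ... )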
+ """ + from nemo_retriever.evaluation.generation import QAGenerationOperator + + if llm is None and model is None: + raise ValueError("generate() requires either llm= or model=") + + if llm is not None: + transport = llm.transport + sampling = llm.sampling + operator = QAGenerationOperator( + model=transport.model, + api_base=transport.api_base, + api_key=transport.api_key, + temperature=sampling.temperature, + max_tokens=sampling.max_tokens, + extra_params=dict(transport.extra_params) if transport.extra_params else None, + num_retries=transport.num_retries, + timeout=transport.timeout, + ) + else: + operator = QAGenerationOperator(model=model, **kwargs) + + self._steps.append(operator) + return self + + def score(self) -> "RetrieverPipelineBuilder": + """Append a :class:`ScoringOperator` step (Tier 1 + Tier 2).""" + from nemo_retriever.evaluation.scoring_operator import ScoringOperator + + self._steps.append(ScoringOperator()) + return self + + def judge( + self, + judge: Optional[Any] = None, + /, + *, + model: Optional[str] = None, + **kwargs: Any, + ) -> "RetrieverPipelineBuilder": + """Append a :class:`JudgingOperator` step (Tier 3). + + Accepts either a pre-built + :class:`~nemo_retriever.llm.clients.judge.LLMJudge` (whose transport params + are unpacked onto the operator) or the flat ``model=...`` kwargs + forwarded to the operator constructor. + + Raises: + ValueError: If neither ``judge`` nor ``model`` is provided. + """ + from nemo_retriever.evaluation.judging import JudgingOperator + + if judge is None and model is None: + raise ValueError("judge() requires either judge= or model=") + + if judge is not None: + transport = judge._client.transport + operator = JudgingOperator( + model=transport.model, + api_base=transport.api_base, + api_key=transport.api_key, + extra_params=dict(transport.extra_params) if transport.extra_params else None, + timeout=transport.timeout, + ) + else: + operator = JudgingOperator(model=model, **kwargs) + + self._steps.append(operator) + return self + + def run( + self, + queries: Any, + *, + reference: Any = None, + ) -> "pd.DataFrame": + """Execute the composed graph on ``queries``. + + Args: + queries: A single query string, a list of query strings, or a + pre-built ``pandas.DataFrame`` (which must contain a + ``query`` column and, when judging/scoring, a + ``reference_answer`` column). + reference: Optional ground-truth answer(s). Accepts a single + string (applied to all queries), a list aligned with + ``queries``, or ``None``. Ignored when ``queries`` is + already a DataFrame. + + Returns: + A ``pandas.DataFrame`` with the columns contributed by each + appended step (always ``query``, ``context``, and + ``context_metadata``; plus ``answer``/``latency_s``/... when + ``.generate()`` ran, and so on). + + Raises: + ValueError: If ``reference`` is a list whose length does not + match ``queries``. 
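+
+        Example (illustrative)::
+
+            >>> df = builder.run(  # doctest: +SKIP
+            ...     queries=["What is RAG?"],
+            ...     reference=["RAG combines retrieval with generation."],
+            ... )
+            >>> {"query", "context", "context_metadata"} <= set(df.columns)  # doctest: +SKIP
+            True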
+ """ + import pandas as pd + + from nemo_retriever.evaluation.live_retrieval import LiveRetrievalOperator + + if isinstance(queries, str): + query_list = [queries] + df = pd.DataFrame({"query": query_list}) + if reference is not None: + refs = reference if isinstance(reference, list) else [reference] + if len(refs) != len(query_list): + raise ValueError("reference length must match queries length") + df["reference_answer"] = refs + elif isinstance(queries, list): + df = pd.DataFrame({"query": list(queries)}) + if reference is not None: + refs = reference if isinstance(reference, list) else [reference] * len(queries) + if len(refs) != len(queries): + raise ValueError("reference length must match queries length") + df["reference_answer"] = refs + elif isinstance(queries, pd.DataFrame): + df = queries.copy() + else: + raise TypeError("queries must be a str, list[str], or pandas.DataFrame; " f"got {type(queries).__name__}") + + retrieval_op = LiveRetrievalOperator(self._retriever, top_k=self._top_k) + if not self._steps: + out = retrieval_op.run(df) + else: + graph = retrieval_op + for step in self._steps: + graph = graph >> step + # ``Graph.execute`` returns one entry per leaf node; a linear + # live-RAG pipeline has exactly one leaf. + leaves = graph.execute(df) + if len(leaves) != 1: + raise RuntimeError(f"Unexpected pipeline fan-out: got {len(leaves)} leaf outputs") + out = leaves[0] + + # Surface generation failure rate when the pipeline ran generation. + # ``QAGenerationOperator`` writes ``gen_error`` per row (non-null on + # failure); downstream aggregators read ``df.attrs`` without having + # to scan rows or guard on column presence themselves. + if "gen_error" in out.columns and len(out) > 0: + out.attrs["generation_failure_rate"] = float(out["gen_error"].notna().mean()) + + return out + + # Backward compatibility alias. retriever = Retriever diff --git a/nemo_retriever/tests/test_live_rag.py b/nemo_retriever/tests/test_live_rag.py new file mode 100644 index 000000000..d9086b32e --- /dev/null +++ b/nemo_retriever/tests/test_live_rag.py @@ -0,0 +1,558 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the live RAG SDK surface on ``Retriever``. + +Covers: + * Protocol compliance of ``Retriever`` against ``RetrieverStrategy``. + * ``Retriever.retrieve`` shape (``RetrievalResult`` with aligned + ``chunks`` / ``metadata`` from the raw ``.query()`` hits). + * ``Retriever.answer`` for all four tiers: + - no reference -> scoring and judge skipped. + - reference without judge -> Tier 1+2 populated, Tier 3 None. + - reference with judge -> all tiers populated, ``failure_mode`` set. + - judge without reference -> ``ValueError``. + * Generation error short-circuits scoring and judge. + * Scoring runs concurrently with the judge (wall-clock proof). + * ``RetrieverPipelineBuilder`` composition and skip-steps behaviour. + * ``LiveRetrievalOperator.process`` populates ``context`` and + ``context_metadata``. 
+""" + +from __future__ import annotations + +import time +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest + + +def _make_retriever(): + """Build a bare ``Retriever`` instance with all defaults.""" + from nemo_retriever.retriever import Retriever + + return Retriever() + + +def _fake_hits() -> list[dict]: + """Three fake LanceDB hits matching the shape Retriever.query() returns.""" + return [ + {"text": "Retrieval augmented generation combines context and LLMs.", "source": "doc-1.pdf", "page_number": 1}, + { + "text": "RAG pipelines retrieve passages then feed them to a generator.", + "source": "doc-1.pdf", + "page_number": 2, + }, + {"text": "Noisy unrelated chunk.", "source": "doc-9.pdf", "page_number": 7}, + ] + + +def _fake_generation(answer: str = "RAG retrieves context and uses an LLM.", error: str | None = None): + """Build a GenerationResult with the given answer / error.""" + from nemo_retriever.llm.types import GenerationResult + + return GenerationResult(answer=answer, latency_s=0.12, model="fake-llm/test", error=error) + + +def _fake_judge_result(score: int | None = 5, reasoning: str = "correct and complete"): + from nemo_retriever.llm.types import JudgeResult + + return JudgeResult(score=score, reasoning=reasoning, error=None) + + +class TestRetrieveProtocol: + """Retriever as a RetrieverStrategy adapter over .query().""" + + def test_retriever_satisfies_protocol(self): + """isinstance check must pass via @runtime_checkable.""" + from nemo_retriever.llm.types import RetrieverStrategy + + r = _make_retriever() + assert isinstance(r, RetrieverStrategy) + + def test_retrieve_returns_result_shape(self): + """``retrieve`` adapts ``.query()`` hits into a RetrievalResult.""" + from nemo_retriever.llm.types import RetrievalResult + + r = _make_retriever() + with patch.object(r, "query", return_value=_fake_hits()) as mock_query: + result = r.retrieve("What is RAG?", top_k=3) + + assert isinstance(result, RetrievalResult) + assert len(result.chunks) == 3 + assert result.chunks[0].startswith("Retrieval augmented") + assert len(result.metadata) == 3 + assert result.metadata[0] == {"source": "doc-1.pdf", "page_number": 1} + assert "text" not in result.metadata[0] + mock_query.assert_called_once() + + def test_retrieve_top_k_override_is_scoped(self): + """``top_k`` override applies only for the call, then restores.""" + r = _make_retriever() + original_top_k = r.top_k + with patch.object(r, "query", return_value=_fake_hits()): + r.retrieve("q", top_k=3) + assert r.top_k == original_top_k + + +class TestAnswer: + """Retriever.answer -- retrieve -> generate -> optional scoring + judge.""" + + def test_answer_without_reference(self): + """No reference, no judge -> scoring / judge fields all None.""" + r = _make_retriever() + llm = MagicMock() + llm.generate.return_value = _fake_generation() + + with patch.object(r, "query", return_value=_fake_hits()): + result = r.answer("q?", llm=llm) + + assert result.answer == "RAG retrieves context and uses an LLM." 
+ assert result.chunks and result.metadata + assert result.model == "fake-llm/test" + assert result.error is None + assert result.token_f1 is None + assert result.exact_match is None + assert result.answer_in_context is None + assert result.judge_score is None + assert result.failure_mode is None + + def test_answer_with_reference_no_judge(self): + """Reference supplied -> Tier 1+2 populated, Tier 3 left None.""" + r = _make_retriever() + llm = MagicMock() + llm.generate.return_value = _fake_generation( + answer="RAG retrieves passages and feeds them to a generator.", + ) + + with patch.object(r, "query", return_value=_fake_hits()): + result = r.answer( + "What is RAG?", + llm=llm, + reference="RAG retrieves passages and feeds them to a generator.", + ) + + assert result.token_f1 == pytest.approx(1.0, abs=1e-6) + assert result.exact_match is True + assert result.answer_in_context is not None + assert result.judge_score is None + assert result.judge_reasoning is None + + def test_answer_with_reference_and_judge(self): + """All tiers populated, ``failure_mode`` derived from combined signals.""" + r = _make_retriever() + llm = MagicMock() + llm.generate.return_value = _fake_generation() + judge = MagicMock() + judge.judge.return_value = _fake_judge_result(score=5) + + with patch.object(r, "query", return_value=_fake_hits()): + result = r.answer( + "What is RAG?", + llm=llm, + judge=judge, + reference="RAG retrieves context and uses an LLM.", + ) + + assert result.judge_score == 5 + assert result.judge_reasoning == "correct and complete" + assert result.token_f1 is not None + assert result.exact_match is not None + assert result.answer_in_context is not None + assert result.failure_mode == "correct" + judge.judge.assert_called_once() + + def test_answer_judge_requires_reference(self): + """Passing a judge without a reference must raise ValueError.""" + r = _make_retriever() + llm = MagicMock() + judge = MagicMock() + + with pytest.raises(ValueError, match="judge requires reference"): + r.answer("q", llm=llm, judge=judge) + + def test_answer_generation_error_short_circuits(self): + """On generation error: result.error set, scoring and judge skipped.""" + r = _make_retriever() + llm = MagicMock() + llm.generate.return_value = _fake_generation(answer="", error="TimeoutError") + judge = MagicMock() + + with patch.object(r, "query", return_value=_fake_hits()): + result = r.answer( + "q", + llm=llm, + judge=judge, + reference="expected", + ) + + assert result.error == "TimeoutError" + assert result.token_f1 is None + assert result.judge_score is None + assert result.failure_mode is None + judge.judge.assert_not_called() + + def test_answer_concurrent_scoring_and_judge(self): + """Scoring + judge must run concurrently, not serially. + + We make the judge sleep 400ms. Scoring is sub-millisecond pure-CPU, + so if scoring + judge run in parallel the total wall time is + dominated by the judge. A serial implementation would add scoring + time on top; on modern CPUs scoring is <5ms so the margin is tight, + but we validate the upper bound at judge_time + 200ms to keep the + test robust under CI jitter. 
+ """ + r = _make_retriever() + llm = MagicMock() + llm.generate.return_value = _fake_generation() + + judge_latency = 0.4 + judge = MagicMock() + + def _slow_judge(query, reference, candidate): + time.sleep(judge_latency) + return _fake_judge_result(score=4) + + judge.judge.side_effect = _slow_judge + + with patch.object(r, "query", return_value=_fake_hits()): + start = time.perf_counter() + result = r.answer( + "q", + llm=llm, + judge=judge, + reference="RAG retrieves context and uses an LLM.", + ) + elapsed = time.perf_counter() - start + + assert result.judge_score == 4 + assert result.token_f1 is not None + assert ( + elapsed < judge_latency + 0.2 + ), f"Expected concurrent scoring+judge wall time < {judge_latency + 0.2:.2f}s, got {elapsed:.3f}s" + + +class TestLiveRetrievalOperator: + """LiveRetrievalOperator adapts Retriever.retrieve_batch() into an EvalOperator.""" + + def test_process_populates_context_columns(self): + """Single ``retrieve_batch`` call covers the whole DataFrame. + + This is the batched contract -- one embed/LanceDB round trip for + all rows. The earlier per-row contract (N ``retrieve`` calls on + an N-row frame) was retired because it scaled linearly with RTT + to the embed NIM. Asserting the call count here guards against a + regression back to the quadratic path. + """ + + from nemo_retriever.evaluation.live_retrieval import LiveRetrievalOperator + from nemo_retriever.llm.types import RetrievalResult + + mock_retriever = MagicMock() + mock_retriever.retrieve_batch.return_value = [ + RetrievalResult( + chunks=["a", "b"], + metadata=[{"source": "s1"}, {"source": "s2"}], + ), + RetrievalResult(chunks=["c"], metadata=[{"source": "s3"}]), + ] + + op = LiveRetrievalOperator(mock_retriever, top_k=5) + df = pd.DataFrame({"query": ["q1", "q2"]}) + + out = op.process(df) + + assert list(out.columns) == ["query", "context", "context_metadata"] + assert out.loc[0, "context"] == ["a", "b"] + assert out.loc[1, "context"] == ["c"] + assert out.loc[0, "context_metadata"] == [{"source": "s1"}, {"source": "s2"}] + + # Exactly one batched call -- the whole point of the operator + # rewrite. ``retrieve`` must not be reached. 
+ assert mock_retriever.retrieve_batch.call_count == 1 + mock_retriever.retrieve.assert_not_called() + + call_args = mock_retriever.retrieve_batch.call_args + queries_arg = call_args.args[0] if call_args.args else call_args.kwargs["queries"] + assert list(queries_arg) == ["q1", "q2"] + assert call_args.kwargs.get("top_k") == 5 + + def test_process_scales_to_ten_rows_with_single_call(self): + """A 10-row frame still triggers exactly one ``retrieve_batch`` call.""" + + from nemo_retriever.evaluation.live_retrieval import LiveRetrievalOperator + from nemo_retriever.llm.types import RetrievalResult + + mock_retriever = MagicMock() + mock_retriever.retrieve_batch.return_value = [ + RetrievalResult(chunks=[f"chunk-{i}"], metadata=[{"row": i}]) for i in range(10) + ] + + op = LiveRetrievalOperator(mock_retriever, top_k=3) + df = pd.DataFrame({"query": [f"q{i}" for i in range(10)]}) + + out = op.process(df) + + assert len(out) == 10 + assert mock_retriever.retrieve_batch.call_count == 1 + + def test_process_rejects_mismatched_batch_length(self): + """Guard against a retrieve_batch that drops or duplicates rows.""" + + from nemo_retriever.evaluation.live_retrieval import LiveRetrievalOperator + from nemo_retriever.llm.types import RetrievalResult + + mock_retriever = MagicMock() + mock_retriever.retrieve_batch.return_value = [ + RetrievalResult(chunks=["a"], metadata=[{"source": "s1"}]), + ] + + op = LiveRetrievalOperator(mock_retriever, top_k=3) + df = pd.DataFrame({"query": ["q1", "q2"]}) + + with pytest.raises(RuntimeError, match="retrieve_batch returned"): + op.process(df) + + def test_process_requires_dataframe(self): + from nemo_retriever.evaluation.live_retrieval import LiveRetrievalOperator + + op = LiveRetrievalOperator(MagicMock(), top_k=3) + with pytest.raises(TypeError, match="requires a pandas.DataFrame"): + op.process({"query": ["q"]}) + + +class TestRetrieveBatch: + """Batched analogue of ``retrieve`` -- one embed call for all rows.""" + + def test_retrieve_batch_returns_aligned_results(self): + """Length + order invariants hold across the batch.""" + + r = _make_retriever() + hits_per_query = [ + [ + {"text": "chunk-0-0", "source": "doc-A.pdf"}, + {"text": "chunk-0-1", "source": "doc-B.pdf"}, + ], + [{"text": "chunk-1-0", "source": "doc-C.pdf"}], + [], + ] + + with patch.object(r, "queries", return_value=hits_per_query) as mock_queries: + results = r.retrieve_batch(["q0", "q1", "q2"], top_k=4) + + assert mock_queries.call_count == 1 + assert len(results) == 3 + assert results[0].chunks == ["chunk-0-0", "chunk-0-1"] + assert results[0].metadata == [{"source": "doc-A.pdf"}, {"source": "doc-B.pdf"}] + assert results[1].chunks == ["chunk-1-0"] + assert results[2].chunks == [] and results[2].metadata == [] + + def test_retrieve_batch_top_k_is_scoped(self): + """``top_k`` override must not persist on the instance.""" + + r = _make_retriever() + original_top_k = r.top_k + with patch.object(r, "queries", return_value=[[]]): + r.retrieve_batch(["q"], top_k=42) + assert r.top_k == original_top_k + + def test_retrieve_batch_empty_input(self): + """Empty input returns an empty list and does not call ``queries``.""" + + r = _make_retriever() + with patch.object(r, "queries") as mock_queries: + assert r.retrieve_batch([]) == [] + mock_queries.assert_not_called() + + +class TestPipelineBuilder: + """Retriever.pipeline() fluent builder composition.""" + + def test_builder_composition_runs_expected_steps(self): + """generate -> score -> judge builds and executes the full chain.""" + r = 
_make_retriever() + hits = _fake_hits() + + # LiveRetrievalOperator uses retrieve_batch which delegates to queries(). + with patch.object(r, "queries", return_value=[hits]): + # Mock out the three EvalOperator classes that the builder imports + # lazily so we can assert which ones were appended and executed. + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator") as mock_gen_cls, patch( + "nemo_retriever.evaluation.scoring_operator.ScoringOperator" + ) as mock_score_cls, patch("nemo_retriever.evaluation.judging.JudgingOperator") as mock_judge_cls: + # Configure each mocked operator to pass the DataFrame through + # with a sentinel column so we can verify each step ran. + def _gen_process(df, **_): + out = df.copy() + out["answer"] = ["gen-out"] * len(out) + return out + + def _score_process(df, **_): + out = df.copy() + out["token_f1"] = [1.0] * len(out) + return out + + def _judge_process(df, **_): + out = df.copy() + out["judge_score"] = [5] * len(out) + return out + + mock_gen_cls.return_value = _build_mock_operator("QAGenerationOperator", _gen_process) + mock_score_cls.return_value = _build_mock_operator("ScoringOperator", _score_process) + mock_judge_cls.return_value = _build_mock_operator("JudgingOperator", _judge_process) + + llm = _build_fake_llm_client() + judge = _build_fake_judge() + + df_out = r.pipeline().generate(llm).score().judge(judge).run(queries=["q1"], reference=["r1"]) + + assert isinstance(df_out, pd.DataFrame) + assert "context" in df_out.columns + assert "answer" in df_out.columns + assert "token_f1" in df_out.columns + assert "judge_score" in df_out.columns + assert mock_gen_cls.called + assert mock_score_cls.called + assert mock_judge_cls.called + + def test_builder_skip_steps(self): + """.pipeline().generate(llm).run([q]) skips score and judge.""" + r = _make_retriever() + + with patch.object(r, "queries", return_value=[_fake_hits()]): + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator") as mock_gen_cls, patch( + "nemo_retriever.evaluation.scoring_operator.ScoringOperator" + ) as mock_score_cls, patch("nemo_retriever.evaluation.judging.JudgingOperator") as mock_judge_cls: + + def _gen_process(df, **_): + out = df.copy() + out["answer"] = ["answer"] * len(out) + return out + + mock_gen_cls.return_value = _build_mock_operator("QAGenerationOperator", _gen_process) + + llm = _build_fake_llm_client() + + df_out = r.pipeline().generate(llm).run(queries=["q"]) + + assert "context" in df_out.columns + assert "answer" in df_out.columns + assert "token_f1" not in df_out.columns + assert "judge_score" not in df_out.columns + assert mock_gen_cls.called + mock_score_cls.assert_not_called() + mock_judge_cls.assert_not_called() + + def test_builder_generate_requires_llm_or_model(self): + r = _make_retriever() + with pytest.raises(ValueError, match="requires either llm= or model="): + r.pipeline().generate() + + def test_builder_judge_requires_judge_or_model(self): + r = _make_retriever() + with pytest.raises(ValueError, match="requires either judge= or model="): + r.pipeline().judge() + + def test_builder_reference_length_must_match(self): + r = _make_retriever() + llm = _build_fake_llm_client() + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator"): + with pytest.raises(ValueError, match="reference length must match"): + r.pipeline().generate(llm).run(queries=["q1", "q2"], reference=["r1"]) + + def test_builder_surfaces_generation_failure_rate(self): + """``gen_error`` column drives 
``df.attrs['generation_failure_rate']``. + + Batch eval jobs that quietly skip scoring on generation failures + would otherwise report misleading success rates: the fraction of + rows with populated ``gen_error`` is attached as a DataFrame + attribute so aggregators have a single authoritative field to + read. No row-level schema change needed. + """ + + r = _make_retriever() + hits = _fake_hits() + + with patch.object(r, "queries", return_value=[hits, hits, hits]): + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator") as mock_gen_cls: + + def _gen_process(df, **_): + out = df.copy() + out["answer"] = ["answer", "", ""] + out["gen_error"] = [None, "TimeoutError", "RateLimitError"] + return out + + mock_gen_cls.return_value = _build_mock_operator("QAGenerationOperator", _gen_process) + + llm = _build_fake_llm_client() + df_out = r.pipeline().generate(llm).run(queries=["q1", "q2", "q3"]) + + assert "generation_failure_rate" in df_out.attrs + assert df_out.attrs["generation_failure_rate"] == pytest.approx(2 / 3, abs=1e-6) + + def test_builder_skips_generation_failure_rate_without_gen_error(self): + """No ``gen_error`` column -> no attrs pollution on retrieval-only runs.""" + + r = _make_retriever() + + with patch.object(r, "queries", return_value=[_fake_hits()]): + df_out = r.pipeline().run(queries=["q"]) + + assert "generation_failure_rate" not in df_out.attrs + + +def _build_mock_operator(class_name: str, process_fn): + """Build a mock operator that cooperates with the graph framework. + + The object must satisfy ``isinstance(op, AbstractOperator)`` so the + pipeline_graph ``Node`` accepts it, and must expose ``.run(df)`` since + ``Graph._execute_node`` invokes that. We subclass the real + ``EvalOperator`` so required-column validation does not fire, and + simply override ``process``. + """ + from nemo_retriever.evaluation.eval_operator import EvalOperator + + class _Mock(EvalOperator): + required_columns = () + output_columns = () + + def __init__(self): + super().__init__() + + def process(self, data, **kwargs): + return process_fn(data, **kwargs) + + op = _Mock() + op.__class__.__name__ = class_name + return op + + +def _build_fake_llm_client(): + """Build a fake LiteLLMClient-shaped object for the builder.""" + transport = SimpleNamespace( + model="fake-llm/test", + api_base=None, + api_key=None, + extra_params={}, + num_retries=3, + timeout=120.0, + ) + sampling = SimpleNamespace(temperature=0.0, max_tokens=512) + return SimpleNamespace(transport=transport, sampling=sampling) + + +def _build_fake_judge(): + """Build a fake LLMJudge-shaped object for the builder.""" + transport = SimpleNamespace( + model="fake-judge/test", + api_base=None, + api_key=None, + extra_params={}, + num_retries=3, + timeout=120.0, + ) + client = SimpleNamespace(transport=transport) + return SimpleNamespace(_client=client) diff --git a/nemo_retriever/tests/test_llm_params.py b/nemo_retriever/tests/test_llm_params.py new file mode 100644 index 000000000..7a1d3bf83 --- /dev/null +++ b/nemo_retriever/tests/test_llm_params.py @@ -0,0 +1,395 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Tests for the unified LLM params layer and client / judge composition. + +Covers: + * LLMRemoteClientParams validation and api_key auto-resolution + * LiteLLMClient(transport, sampling) and .from_kwargs(...) 
parity + * top_p omission from litellm call kwargs when unset + * LLMJudge default sampling and .from_kwargs(...) back-compat +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + + +def _fake_litellm_response(text: str = "ok") -> SimpleNamespace: + """Mimic the litellm.completion() response shape used by LiteLLMClient.""" + message = SimpleNamespace(content=text) + choice = SimpleNamespace(message=message) + return SimpleNamespace(choices=[choice]) + + +class TestLLMRemoteClientParams: + """Validate LLMRemoteClientParams validators, defaults, and api_key auto-resolution.""" + + def test_defaults(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + p = LLMRemoteClientParams(model="nvidia_nim/meta/llama-3.1-70b-instruct") + assert p.model == "nvidia_nim/meta/llama-3.1-70b-instruct" + assert p.api_base is None + assert p.num_retries == 3 + assert p.timeout == 120.0 + assert p.extra_params == {} + + def test_model_is_required(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + with pytest.raises(ValueError): + LLMRemoteClientParams() # type: ignore[call-arg] + + def test_negative_num_retries_rejected(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + with pytest.raises(ValueError, match="num_retries must be >= 0"): + LLMRemoteClientParams(model="m", num_retries=-1) + + def test_zero_timeout_rejected(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + with pytest.raises(ValueError, match="timeout must be > 0"): + LLMRemoteClientParams(model="m", timeout=0.0) + + def test_negative_timeout_rejected(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + with pytest.raises(ValueError, match="timeout must be > 0"): + LLMRemoteClientParams(model="m", timeout=-1.0) + + def test_extra_forbid(self): + """Unknown kwargs should be rejected by _ParamsModel(extra='forbid').""" + from nemo_retriever.params.models import LLMRemoteClientParams + + with pytest.raises(ValueError): + LLMRemoteClientParams(model="m", unknown_field=123) # type: ignore[call-arg] + + def test_api_key_auto_resolved_from_env(self, monkeypatch): + """api_key=None should resolve from the remote-auth helper.""" + from nemo_retriever.params import models as params_models + + monkeypatch.setattr(params_models, "resolve_remote_api_key", lambda: "resolved-secret") + p = params_models.LLMRemoteClientParams(model="m") + assert p.api_key == "resolved-secret" + + def test_api_key_no_api_key_sentinel_yields_none(self): + """Explicit NO_API_KEY sentinel suppresses auto-resolution.""" + from nemo_retriever.params.models import NO_API_KEY, LLMRemoteClientParams + + p = LLMRemoteClientParams(model="m", api_key=NO_API_KEY) + assert p.api_key is None + + +class TestLiteLLMClientConstruction: + """LiteLLMClient should accept structured params and expose .model for back-compat.""" + + def test_structured_construction(self): + from nemo_retriever.llm.clients import LiteLLMClient + from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams + + transport = LLMRemoteClientParams(model="openai/gpt-4o-mini", api_key="k") + sampling = LLMInferenceParams(temperature=0.2, top_p=0.9, max_tokens=512) + client = LiteLLMClient(transport=transport, sampling=sampling) + + assert client.transport is transport + assert client.sampling is sampling + assert client.model == "openai/gpt-4o-mini" + + def 
test_default_sampling_is_llminferenceparams_defaults(self): + from nemo_retriever.llm.clients import LiteLLMClient + from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams + + client = LiteLLMClient(transport=LLMRemoteClientParams(model="m")) + assert isinstance(client.sampling, LLMInferenceParams) + assert client.sampling.temperature == 1.0 + assert client.sampling.top_p is None + assert client.sampling.max_tokens == 1024 + + def test_from_kwargs_matches_explicit(self): + from nemo_retriever.llm.clients import LiteLLMClient + + flat = LiteLLMClient.from_kwargs( + model="openai/gpt-4o-mini", + api_key="k", + temperature=0.3, + top_p=0.8, + max_tokens=256, + num_retries=5, + timeout=30.0, + extra_params={"user": "tester"}, + ) + assert flat.transport.model == "openai/gpt-4o-mini" + assert flat.transport.api_key == "k" + assert flat.transport.num_retries == 5 + assert flat.transport.timeout == 30.0 + assert flat.transport.extra_params == {"user": "tester"} + assert flat.sampling.temperature == 0.3 + assert flat.sampling.top_p == 0.8 + assert flat.sampling.max_tokens == 256 + + def test_from_kwargs_defaults_top_p_to_none(self): + """The old flat default of top_p=1.0 is now top_p=None (behavior fix).""" + from nemo_retriever.llm.clients import LiteLLMClient + + client = LiteLLMClient.from_kwargs(model="m") + assert client.sampling.top_p is None + + +class TestLiteLLMCompleteCallKwargs: + """Inspect the exact kwargs LiteLLMClient.complete() forwards to litellm.""" + + @patch("litellm.completion") + def test_top_p_omitted_when_none(self, mock_completion): + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("hi") + client = LiteLLMClient.from_kwargs(model="openai/gpt-4o-mini", temperature=0.5) + client.complete([{"role": "user", "content": "hi"}]) + + kwargs = mock_completion.call_args.kwargs + assert kwargs["model"] == "openai/gpt-4o-mini" + assert kwargs["temperature"] == 0.5 + assert kwargs["max_tokens"] == 4096 + assert "top_p" not in kwargs + assert kwargs["num_retries"] == 3 + assert kwargs["timeout"] == 120.0 + + @patch("litellm.completion") + def test_top_p_forwarded_when_set(self, mock_completion): + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("hi") + client = LiteLLMClient.from_kwargs(model="m", top_p=0.9) + client.complete([{"role": "user", "content": "hi"}]) + + kwargs = mock_completion.call_args.kwargs + assert kwargs["top_p"] == 0.9 + + @patch("litellm.completion") + def test_max_tokens_override(self, mock_completion): + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("hi") + client = LiteLLMClient.from_kwargs(model="m", max_tokens=4096) + client.complete([{"role": "user", "content": "hi"}], max_tokens=128) + + kwargs = mock_completion.call_args.kwargs + assert kwargs["max_tokens"] == 128 + + @patch("litellm.completion") + def test_api_key_and_api_base_forwarded(self, mock_completion): + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("hi") + client = LiteLLMClient.from_kwargs( + model="openai/gpt-4o-mini", + api_base="http://local-vllm:8000/v1", + api_key="secret", + ) + client.complete([{"role": "user", "content": "hi"}]) + + kwargs = mock_completion.call_args.kwargs + assert kwargs["api_base"] == "http://local-vllm:8000/v1" + assert kwargs["api_key"] == "secret" + + @patch("litellm.completion") + def 
test_extra_params_merged_last(self, mock_completion): + """extra_params should win over keys it overlaps with.""" + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("hi") + client = LiteLLMClient.from_kwargs( + model="m", + extra_params={"user": "tester", "num_retries": 99}, + ) + client.complete([{"role": "user", "content": "hi"}]) + + kwargs = mock_completion.call_args.kwargs + assert kwargs["user"] == "tester" + assert kwargs["num_retries"] == 99 + + +class TestLLMJudgeConstruction: + """LLMJudge should default to deterministic sampling and expose .model.""" + + def test_structured_construction_uses_defaults(self): + from nemo_retriever.llm.clients import LLMJudge + from nemo_retriever.params.models import LLMRemoteClientParams + + transport = LLMRemoteClientParams(model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1") + judge = LLMJudge(transport=transport) + assert judge.model == "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1" + assert judge._client.sampling.temperature == 0.0 + assert judge._client.sampling.max_tokens == 256 + + def test_custom_sampling_override(self): + from nemo_retriever.llm.clients import LLMJudge + from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams + + transport = LLMRemoteClientParams(model="m") + sampling = LLMInferenceParams(temperature=0.4, max_tokens=1024) + judge = LLMJudge(transport=transport, sampling=sampling) + assert judge._client.sampling.temperature == 0.4 + assert judge._client.sampling.max_tokens == 1024 + + def test_from_kwargs_matches_structured(self): + from nemo_retriever.llm.clients import LLMJudge + + judge = LLMJudge.from_kwargs( + model="m", + api_key="k", + num_retries=2, + timeout=60.0, + extra_params={"user": "t"}, + ) + assert judge._client.transport.model == "m" + assert judge._client.transport.api_key == "k" + assert judge._client.transport.num_retries == 2 + assert judge._client.transport.timeout == 60.0 + assert judge._client.transport.extra_params == {"user": "t"} + # Sampling stays at judge defaults even when using flat constructor. 
+ assert judge._client.sampling.temperature == 0.0 + assert judge._client.sampling.max_tokens == 256 + + def test_from_kwargs_uses_default_model(self): + from nemo_retriever.llm.clients import LLMJudge + + judge = LLMJudge.from_kwargs() + assert judge.model == LLMJudge._DEFAULT_MODEL + + @patch("litellm.completion") + def test_judge_returns_parsed_result(self, mock_completion): + from nemo_retriever.llm.clients import LLMJudge + + mock_completion.return_value = _fake_litellm_response( + '{"score": 4, "reasoning": "mostly correct"}', + ) + judge = LLMJudge.from_kwargs(model="m") + verdict = judge.judge(query="q", reference="ref", candidate="cand") + assert verdict.score == 4 + assert verdict.reasoning == "mostly correct" + assert verdict.error is None + + def test_judge_empty_candidate_short_circuits(self): + """Empty candidate is handled locally with no LLM call.""" + from nemo_retriever.llm.clients import LLMJudge + + with patch("litellm.completion") as mock_completion: + judge = LLMJudge.from_kwargs(model="m") + verdict = judge.judge(query="q", reference="r", candidate=" ") + mock_completion.assert_not_called() + + assert verdict.score is None + assert verdict.error == "empty_candidate" + + +class TestBackCompatCallSites: + """The four migrated call sites all use .from_kwargs, so they must still work.""" + + @patch("litellm.completion") + def test_qa_generation_operator_constructs_cleanly(self, mock_completion): + from nemo_retriever.evaluation.generation import QAGenerationOperator + + mock_completion.return_value = _fake_litellm_response("answer") + op = QAGenerationOperator(model="m", temperature=0.0, max_tokens=128) + assert op._client.transport.model == "m" + assert op._client.sampling.temperature == 0.0 + assert op._client.sampling.max_tokens == 128 + + def test_judging_operator_constructs_cleanly(self): + from nemo_retriever.evaluation.judging import JudgingOperator + + op = JudgingOperator(model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1") + assert op._judge.model == "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1" + assert op._judge._client.sampling.temperature == 0.0 + + +class TestApiKeyRedaction: + """Guard the repr/str of every transport params object against key leakage. + + The ``_ParamsModel`` base redacts ``api_key`` + ``*_api_key`` fields in + ``__repr__`` / ``__str__`` so that logging a transport object (or + letting Pydantic's default error formatter echo one back) never + prints a bearer token. Consumers still read the plain ``str`` via + attribute access, so no downstream litellm/NIM call is affected. 
+ """ + + def test_api_key_masked_in_repr(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + p = LLMRemoteClientParams(model="m", api_key="nvapi-SECRET-TOKEN") + rendered = repr(p) + assert "nvapi-SECRET-TOKEN" not in rendered + assert "api_key=***" in rendered + + def test_api_key_masked_in_str(self): + from nemo_retriever.params.models import LLMRemoteClientParams + + p = LLMRemoteClientParams(model="m", api_key="nvapi-SECRET-TOKEN") + assert "nvapi-SECRET-TOKEN" not in str(p) + + def test_api_key_attribute_is_plain_str(self): + """Redaction is display-only -- attribute access still yields the raw string.""" + from nemo_retriever.params.models import LLMRemoteClientParams + + p = LLMRemoteClientParams(model="m", api_key="nvapi-SECRET-TOKEN") + assert p.api_key == "nvapi-SECRET-TOKEN" + assert isinstance(p.api_key, str) + + def test_empty_api_key_not_masked(self): + """Redaction only fires when a key is actually present.""" + from nemo_retriever.params.models import NO_API_KEY, LLMRemoteClientParams + + p = LLMRemoteClientParams(model="m", api_key=NO_API_KEY) + assert p.api_key is None + assert "api_key=***" not in repr(p) + assert "api_key=None" in repr(p) + + @patch("litellm.completion") + def test_plain_str_reaches_litellm_call_site(self, mock_completion): + """The redacted __repr__ must not break the wire-format contract.""" + from nemo_retriever.llm.clients import LiteLLMClient + + mock_completion.return_value = _fake_litellm_response("ok") + client = LiteLLMClient.from_kwargs(model="m", api_key="nvapi-SECRET-TOKEN") + client.generate(query="q", chunks=[]) + + _, call_kwargs = mock_completion.call_args + assert call_kwargs["api_key"] == "nvapi-SECRET-TOKEN" + assert isinstance(call_kwargs["api_key"], str) + + def test_nested_api_key_fields_also_masked(self): + """Fields matching *_api_key (not only bare api_key) get redacted.""" + from nemo_retriever.params.models import ExtractParams + + p = ExtractParams( + page_elements_api_key="nvapi-PAGE-ELEM-TOKEN", + ocr_api_key="nvapi-OCR-TOKEN", + ) + rendered = repr(p) + assert "nvapi-PAGE-ELEM-TOKEN" not in rendered + assert "nvapi-OCR-TOKEN" not in rendered + assert "page_elements_api_key=***" in rendered + assert "ocr_api_key=***" in rendered + + +class TestLiteLLMDefaultModel: + """Mirror of LLMJudge._DEFAULT_MODEL coverage for LiteLLMClient.""" + + def test_from_kwargs_uses_default_model(self): + from nemo_retriever.llm.clients import LiteLLMClient + + client = LiteLLMClient.from_kwargs() + assert client.model == LiteLLMClient._DEFAULT_MODEL + + def test_default_model_is_a_non_empty_string(self): + from nemo_retriever.llm.clients import LiteLLMClient + + assert isinstance(LiteLLMClient._DEFAULT_MODEL, str) + assert LiteLLMClient._DEFAULT_MODEL diff --git a/nemo_retriever/uv.lock b/nemo_retriever/uv.lock index a1f761ba9..7fcb234e5 100644 --- a/nemo_retriever/uv.lock +++ b/nemo_retriever/uv.lock @@ -2823,10 +2823,8 @@ dev = [ { name = "build" }, { name = "pytest" }, ] -eval = [ +llm = [ { name = "litellm" }, - { name = "pyyaml" }, - { name = "tenacity" }, ] local = [ { name = "accelerate" }, @@ -2883,7 +2881,7 @@ requires-dist = [ { name = "httpx", specifier = ">=0.27.0" }, { name = "lancedb" }, { name = "langchain-nvidia-ai-endpoints", specifier = ">=0.3.0" }, - { name = "litellm", marker = "extra == 'eval'", specifier = ">=1.40.0" }, + { name = "litellm", marker = "extra == 'llm'", specifier = ">=1.40.0" }, { name = "markitdown" }, { name = "nemo-retriever", extras = ["benchmarks", "local", "multimedia", 
"stores"], marker = "extra == 'all'" }, { name = "nemotron-graphic-elements-v1", marker = "extra == 'local'", specifier = ">=0.dev0", index = "https://test.pypi.org/simple/" }, @@ -2905,7 +2903,6 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.2" }, { name = "python-multipart", specifier = ">=0.0.9" }, { name = "pyyaml", specifier = ">=6.0" }, - { name = "pyyaml", marker = "extra == 'eval'", specifier = ">=6.0" }, { name = "ray", extras = ["data", "serve"], specifier = ">=2.49.0" }, { name = "requests", specifier = ">=2.32.5" }, { name = "rich", specifier = ">=13.7.0" }, @@ -2913,7 +2910,6 @@ requires-dist = [ { name = "scipy", marker = "extra == 'multimedia'", specifier = ">=1.11.0" }, { name = "soundfile", marker = "extra == 'multimedia'", specifier = ">=0.12.0" }, { name = "sqlglot", specifier = ">=30.0.0" }, - { name = "tenacity", marker = "extra == 'eval'", specifier = ">=8.0.0" }, { name = "timm", marker = "extra == 'local'", specifier = "==1.0.22" }, { name = "tokenizers", marker = "extra == 'local'", specifier = ">=0.20.3" }, { name = "torch", marker = "(sys_platform == 'linux' and extra == 'local') or (sys_platform == 'win32' and extra == 'local')", specifier = "~=2.9.1", index = "https://download.pytorch.org/whl/cu130" }, @@ -2930,7 +2926,7 @@ requires-dist = [ { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and sys_platform == 'linux' and extra == 'local'", specifier = "==0.16.0" }, { name = "vllm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux' and extra == 'local'", url = "https://github.com/vllm-project/vllm/releases/download/v0.16.0/vllm-0.16.0+cu130-cp38-abi3-manylinux_2_35_x86_64.whl" }, ] -provides-extras = ["local", "multimedia", "stores", "benchmarks", "eval", "dev", "all"] +provides-extras = ["local", "multimedia", "stores", "benchmarks", "llm", "dev", "all"] [[package]] name = "nemotron-graphic-elements-v1" From 2a867ba009d5e8908f7c4d67691466ef8247ecd4 Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Tue, 21 Apr 2026 03:52:12 +0000 Subject: [PATCH 05/10] resolve unsafe thread + silent drop --- .../nemo_retriever/evaluation/generation.py | 3 + .../src/nemo_retriever/retriever.py | 91 ++++++++++++------- nemo_retriever/tests/test_live_rag.py | 91 ++++++++++++++++++- 3 files changed, 148 insertions(+), 37 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/generation.py b/nemo_retriever/src/nemo_retriever/evaluation/generation.py index b879868de..a3ad7f1af 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/generation.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/generation.py @@ -36,6 +36,7 @@ def __init__( api_base: Optional[str] = None, api_key: Optional[str] = None, temperature: float = 0.0, + top_p: Optional[float] = None, max_tokens: int = 4096, extra_params: Optional[dict[str, Any]] = None, num_retries: int = 3, @@ -47,6 +48,7 @@ def __init__( api_base=api_base, api_key=api_key, temperature=temperature, + top_p=top_p, max_tokens=max_tokens, extra_params=extra_params, num_retries=num_retries, @@ -58,6 +60,7 @@ def __init__( api_base=api_base, api_key=api_key, temperature=temperature, + top_p=top_p, max_tokens=max_tokens, extra_params=extra_params, num_retries=num_retries, diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 4390e7d37..84a14cf15 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ 
b/nemo_retriever/src/nemo_retriever/retriever.py @@ -172,6 +172,7 @@ def _search_lancedb( lancedb_table: str, query_vectors: list[list[float]], query_texts: list[str], + top_k: int, ) -> list[list[dict[str, Any]]]: import lancedb # type: ignore import numpy as np @@ -202,7 +203,7 @@ def _search_lancedb( for i, vector in enumerate(query_vectors): q = np.asarray(vector, dtype="float32") # doubling top_k for both hybrid and dense search in order to have more to rerank - top_k = self.top_k if not self.reranker else self.top_k * self.reranker_refine_factor + fanout_top_k = top_k if not self.reranker else top_k * self.reranker_refine_factor if self.hybrid: from lancedb.rerankers import RRFReranker # type: ignore @@ -212,7 +213,7 @@ def _search_lancedb( .text(query_texts[i]) .nprobes(effective_nprobes) .refine_factor(int(self.refine_factor)) - .limit(int(top_k)) + .limit(int(fanout_top_k)) .rerank(RRFReranker()) .to_list() ) @@ -239,7 +240,7 @@ def _search_lancedb( .nprobes(effective_nprobes) .refine_factor(int(self.refine_factor)) .select(select_cols) - .limit(int(top_k)) + .limit(int(fanout_top_k)) .to_list() ) results.append([{k: v for k, v in h.items() if k in _KEEP_KEYS} for h in hits]) @@ -266,6 +267,8 @@ def _rerank_results( self, query_texts: list[str], results: list[list[dict[str, Any]]], + *, + top_k: int, ) -> list[list[dict[str, Any]]]: """Rerank each per-query result list using the configured reranker.""" from nemo_retriever.rerank import rerank_hits @@ -285,7 +288,7 @@ def _rerank_results( api_key=(self.reranker_api_key or "").strip(), max_length=int(self.reranker_max_length), batch_size=int(self.reranker_batch_size), - top_n=int(self.top_k), + top_n=int(top_k), modality=self.rerank_modality, ) ) @@ -299,13 +302,27 @@ def query( self, query: str, *, + top_k: Optional[int] = None, embedder: Optional[str] = None, lancedb_uri: Optional[str] = None, lancedb_table: Optional[str] = None, ) -> list[dict[str, Any]]: - """Run retrieval for a single query string.""" + """Run retrieval for a single query string. + + Args: + query: The natural-language query. + top_k: Per-call override of ``self.top_k``. When ``None`` the + instance attribute is used. The override is passed as a + local value through the search / rerank stack; it never + mutates ``self.top_k``, which keeps concurrent callers on + the same :class:`Retriever` instance thread-safe. + embedder: Per-call embedder override. + lancedb_uri: Per-call LanceDB URI override. + lancedb_table: Per-call LanceDB table override. + """ return self.queries( [query], + top_k=top_k, embedder=embedder, lancedb_uri=lancedb_uri, lancedb_table=lancedb_table, @@ -315,6 +332,7 @@ def queries( self, queries: Sequence[str], *, + top_k: Optional[int] = None, embedder: Optional[str] = None, lancedb_uri: Optional[str] = None, lancedb_table: Optional[str] = None, @@ -325,11 +343,19 @@ def queries( results are re-scored with ``nvidia/llama-nemotron-rerank-1b-v2`` (or the configured endpoint) and returned sorted by cross-encoder score. Each hit gains a ``"_rerank_score"`` key. + + The ``top_k`` argument is resolved into a local ``effective_top_k`` + that is threaded through the search + rerank stack. ``self.top_k`` + is read once (when ``top_k`` is ``None``) and never written, so + concurrent callers sharing a :class:`Retriever` instance cannot + race on the instance attribute. 
""" query_texts = [str(q) for q in queries] if not query_texts: return [] + effective_top_k = int(top_k) if top_k is not None else int(self.top_k) + resolved_embedder = str(embedder or self.embedder) resolved_lancedb_uri = str(lancedb_uri or self.lancedb_uri) resolved_lancedb_table = str(lancedb_table or self.lancedb_table) @@ -352,10 +378,11 @@ def queries( lancedb_table=resolved_lancedb_table, query_vectors=vectors, query_texts=query_texts, + top_k=effective_top_k, ) if self.reranker: - results = self._rerank_results(query_texts, results) + results = self._rerank_results(query_texts, results, top_k=effective_top_k) return results @@ -382,8 +409,10 @@ def retrieve( Args: query: The natural-language query. - top_k: Override ``self.top_k`` for this call. When ``None`` the - instance attribute is used. + top_k: Per-call override of ``self.top_k``. Passed through as + a local value to :meth:`query`; ``self.top_k`` is never + mutated, so concurrent callers sharing a :class:`Retriever` + instance remain thread-safe. embedder: Override ``self.embedder`` for this call. lancedb_uri: Override ``self.lancedb_uri`` for this call. lancedb_table: Override ``self.lancedb_table`` for this call. @@ -400,18 +429,13 @@ def retrieve( """ from nemo_retriever.llm.types import RetrievalResult - previous_top_k = self.top_k - if top_k is not None: - self.top_k = int(top_k) - try: - hits = self.query( - query, - embedder=embedder, - lancedb_uri=lancedb_uri, - lancedb_table=lancedb_table, - ) - finally: - self.top_k = previous_top_k + hits = self.query( + query, + top_k=top_k, + embedder=embedder, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ) chunks: list[str] = [] metadata: list[dict[str, Any]] = [] @@ -443,9 +467,10 @@ def retrieve_batch( Args: queries: Iterable of natural-language query strings. Order is preserved in the returned list. - top_k: Per-call override of ``self.top_k`` (scoped via - ``try/finally`` so the instance attribute is restored on - return, mirroring :meth:`retrieve`). + top_k: Per-call override of ``self.top_k``. Passed through + to :meth:`queries` as a local value, so ``self.top_k`` is + never written. Concurrent callers sharing a single + :class:`Retriever` instance therefore remain thread-safe. embedder: Per-call embedder override. lancedb_uri: Per-call LanceDB URI override. lancedb_table: Per-call LanceDB table override. 
@@ -462,18 +487,13 @@ def retrieve_batch( if not query_texts: return [] - previous_top_k = self.top_k - if top_k is not None: - self.top_k = int(top_k) - try: - hits_per_query = self.queries( - query_texts, - embedder=embedder, - lancedb_uri=lancedb_uri, - lancedb_table=lancedb_table, - ) - finally: - self.top_k = previous_top_k + hits_per_query = self.queries( + query_texts, + top_k=top_k, + embedder=embedder, + lancedb_uri=lancedb_uri, + lancedb_table=lancedb_table, + ) results: list[RetrievalResult] = [] for hits in hits_per_query: @@ -752,6 +772,7 @@ def generate( api_base=transport.api_base, api_key=transport.api_key, temperature=sampling.temperature, + top_p=sampling.top_p, max_tokens=sampling.max_tokens, extra_params=dict(transport.extra_params) if transport.extra_params else None, num_retries=transport.num_retries, diff --git a/nemo_retriever/tests/test_live_rag.py b/nemo_retriever/tests/test_live_rag.py index d9086b32e..af55c8b45 100644 --- a/nemo_retriever/tests/test_live_rag.py +++ b/nemo_retriever/tests/test_live_rag.py @@ -359,6 +359,56 @@ def test_retrieve_batch_top_k_is_scoped(self): r.retrieve_batch(["q"], top_k=42) assert r.top_k == original_top_k + def test_retrieve_batch_forwards_top_k_to_queries(self): + """The per-call ``top_k`` must be forwarded as a kwarg, not via + attribute mutation. This is the regression test for the + Greptile P1 "thread-unsafe self.top_k mutation" finding: under + the old try/finally pattern the value was visible to ``queries`` + only through ``self.top_k``, which was racy under concurrent use. + """ + + r = _make_retriever() + original_top_k = r.top_k + with patch.object(r, "queries", return_value=[[]]) as mock_queries: + r.retrieve_batch(["q"], top_k=7) + mock_queries.assert_called_once() + call_kwargs = mock_queries.call_args.kwargs + assert call_kwargs.get("top_k") == 7 + assert r.top_k == original_top_k + + def test_retrieve_batch_concurrent_distinct_top_k(self): + """Concurrent ``retrieve_batch`` calls with different ``top_k`` + values must not clobber each other. + + Under the old ``previous_top_k = self.top_k; self.top_k = ...; try: + self.queries(...); finally: self.top_k = previous_top_k`` dance + two threads would race on ``self.top_k``: thread A could set + ``top_k=3``, thread B could overwrite with ``top_k=10`` before + thread A's ``queries()`` call read it, and thread A would run + with the wrong k. The new implementation passes ``top_k`` + through as a local kwarg, so each call sees its own value. + """ + + from concurrent.futures import ThreadPoolExecutor + + r = _make_retriever() + + observed: list[int] = [] + lock = __import__("threading").Lock() + + def fake_queries(query_texts, *, top_k, **_kwargs): + with lock: + observed.append(int(top_k)) + return [[] for _ in query_texts] + + with patch.object(r, "queries", side_effect=fake_queries): + with ThreadPoolExecutor(max_workers=4) as pool: + futures = [pool.submit(r.retrieve_batch, ["q"], top_k=k) for k in (1, 3, 7, 15, 42)] + for f in futures: + f.result() + + assert sorted(observed) == [1, 3, 7, 15, 42] + def test_retrieve_batch_empty_input(self): """Empty input returns an empty list and does not call ``queries``.""" @@ -446,6 +496,43 @@ def _gen_process(df, **_): mock_score_cls.assert_not_called() mock_judge_cls.assert_not_called() + def test_builder_forwards_top_p_from_llm_client(self): + """``.generate(llm)`` must forward ``llm.sampling.top_p`` to the + operator. 
+ + Regression test for the Greptile P1 finding that ``top_p`` was + silently dropped when a caller passed a pre-built ``LiteLLMClient`` + with a non-default ``top_p``. + """ + + r = _make_retriever() + + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator") as mock_gen_cls: + mock_gen_cls.return_value = _build_mock_operator("QAGenerationOperator", lambda df, **_: df) + llm = _build_fake_llm_client(top_p=0.7) + r.pipeline().generate(llm) + + mock_gen_cls.assert_called_once() + kwargs = mock_gen_cls.call_args.kwargs + assert kwargs.get("top_p") == 0.7 + assert kwargs.get("temperature") == 0.0 + assert kwargs.get("max_tokens") == 512 + + def test_builder_forwards_none_top_p_when_unset(self): + """Default path (``top_p=None``) must forward ``None`` -- not raise + and not silently substitute a non-default.""" + + r = _make_retriever() + + with patch("nemo_retriever.evaluation.generation.QAGenerationOperator") as mock_gen_cls: + mock_gen_cls.return_value = _build_mock_operator("QAGenerationOperator", lambda df, **_: df) + llm = _build_fake_llm_client() + r.pipeline().generate(llm) + + mock_gen_cls.assert_called_once() + kwargs = mock_gen_cls.call_args.kwargs + assert kwargs.get("top_p") is None + def test_builder_generate_requires_llm_or_model(self): r = _make_retriever() with pytest.raises(ValueError, match="requires either llm= or model="): @@ -530,7 +617,7 @@ def process(self, data, **kwargs): return op -def _build_fake_llm_client(): +def _build_fake_llm_client(*, top_p: float | None = None): """Build a fake LiteLLMClient-shaped object for the builder.""" transport = SimpleNamespace( model="fake-llm/test", @@ -540,7 +627,7 @@ def _build_fake_llm_client(): num_retries=3, timeout=120.0, ) - sampling = SimpleNamespace(temperature=0.0, max_tokens=512) + sampling = SimpleNamespace(temperature=0.0, top_p=top_p, max_tokens=512) return SimpleNamespace(transport=transport, sampling=sampling) From 346a8f7d503163cb96ffb7da91d8a6b5bf93053c Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Tue, 21 Apr 2026 19:04:51 +0000 Subject: [PATCH 06/10] fix minor style fixes + issues --- .../nemo_retriever/evaluation/orchestrator.py | 7 ++- .../nemo_retriever/examples/graph_pipeline.py | 14 +++++ .../src/nemo_retriever/llm/clients/litellm.py | 4 +- .../src/nemo_retriever/retriever.py | 44 +++++----------- nemo_retriever/tests/test_llm_params.py | 52 ++++++++++++++++++- 5 files changed, 82 insertions(+), 39 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py index 7df9d9eba..faa5f5be5 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py @@ -148,10 +148,9 @@ def process(self, data: Any, **kwargs: Any) -> Any: for model_name, client in self.llm_clients.items(): prefix = _sanitize_prefix(model_name) - # _client and _model_name are captured via default args to avoid - # the late-binding closure bug across loop iterations. self.judge - # is safe to access directly since it does not change per-iteration. - def _process_row(row_tuple, row_aic, _client=client, _model_name=model_name): + # ``_client`` is captured via a default arg to pin each iteration's client + # into the closure and avoid the late-binding bug. 
+ def _process_row(row_tuple, row_aic, _client=client): _, row = row_tuple query = row["query"] ref = row["reference_answer"] diff --git a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py index c576c7ffd..0a9160a78 100644 --- a/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py +++ b/nemo_retriever/src/nemo_retriever/examples/graph_pipeline.py @@ -660,6 +660,8 @@ def main( evaluation_total_time = 0.0 evaluation_metrics: dict[str, float] = {} evaluation_query_count: Optional[int] = None + # Tracks whether the QA sweep had any non-PASS results so CI gets a non-zero exit after cleanup. + qa_failed = False if evaluation_mode == "beir": if not beir_loader: @@ -759,10 +761,17 @@ def main( passed = sum(1 for r in sweep_results if r["status"] == "PASS") logger.info("QA sweep complete: %d/%d passed", passed, len(sweep_results)) + qa_failed = passed < len(sweep_results) for r in sweep_results: if r["status"] == "PASS": out = Path(r["output_path"]).resolve() logger.info("Results: %s", out) + else: + logger.error( + "QA run FAILED: %s: %s", + r.get("label", ""), + r.get("error", ""), + ) er = r.get("eval_results", {}) judge_scores = er.get("tier3_llm_judge", {}) for gen_name, stats in judge_scores.items(): @@ -862,6 +871,11 @@ def main( evaluation_label=evaluation_label, evaluation_count=evaluation_query_count, ) + + # Raise the non-zero exit after summary + Ray shutdown so CI still collects + # diagnostic artifacts, and the outer ``finally`` still restores stdout/stderr. + if qa_failed: + raise typer.Exit(code=1) finally: os.sys.stdout = original_stdout os.sys.stderr = original_stderr diff --git a/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py index 360c0f2bc..599e8fe14 100644 --- a/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py +++ b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py @@ -78,7 +78,9 @@ def __init__( sampling: Optional[LLMInferenceParams] = None, ): self.transport = transport - self.sampling = sampling if sampling is not None else LLMInferenceParams() + # Default to ``temperature=0.0`` so the structured constructor matches ``from_kwargs`` + # and keeps RAG-eval runs deterministic. + self.sampling = sampling if sampling is not None else LLMInferenceParams(temperature=0.0) @property def model(self) -> str: diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 84a14cf15..68b58175c 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -311,11 +311,8 @@ def query( Args: query: The natural-language query. - top_k: Per-call override of ``self.top_k``. When ``None`` the - instance attribute is used. The override is passed as a - local value through the search / rerank stack; it never - mutates ``self.top_k``, which keeps concurrent callers on - the same :class:`Retriever` instance thread-safe. + top_k: Per-call override of ``self.top_k``; passed as a local + value so the instance attribute is never mutated. embedder: Per-call embedder override. lancedb_uri: Per-call LanceDB URI override. lancedb_table: Per-call LanceDB table override. @@ -344,11 +341,8 @@ def queries( (or the configured endpoint) and returned sorted by cross-encoder score. Each hit gains a ``"_rerank_score"`` key. - The ``top_k`` argument is resolved into a local ``effective_top_k`` - that is threaded through the search + rerank stack. 
``self.top_k`` - is read once (when ``top_k`` is ``None``) and never written, so - concurrent callers sharing a :class:`Retriever` instance cannot - race on the instance attribute. + The ``top_k`` kwarg is threaded through the search + rerank stack + as a local value so concurrent callers never race on ``self.top_k``. """ query_texts = [str(q) for q in queries] if not query_texts: @@ -409,10 +403,8 @@ def retrieve( Args: query: The natural-language query. - top_k: Per-call override of ``self.top_k``. Passed through as - a local value to :meth:`query`; ``self.top_k`` is never - mutated, so concurrent callers sharing a :class:`Retriever` - instance remain thread-safe. + top_k: Per-call override of ``self.top_k``; passed as a local + value so the instance attribute is never mutated. embedder: Override ``self.embedder`` for this call. lancedb_uri: Override ``self.lancedb_uri`` for this call. lancedb_table: Override ``self.lancedb_table`` for this call. @@ -455,22 +447,14 @@ def retrieve_batch( ) -> list["RetrievalResult"]: """Run retrieval for a batch of queries in a single embedder call. - This is the batched analogue of :meth:`retrieve`. It funnels the - whole query list through :meth:`queries`, which already dispatches - exactly one call to ``_embed_queries_nim`` (or the local HF - embedder) regardless of ``len(queries)``. Callers that previously - looped over :meth:`retrieve` per row pay ``N`` sequential round - trips to the embed service; routing through ``retrieve_batch`` - collapses that to a single request and a single LanceDB search - sweep. + Funnels the whole query list through :meth:`queries`, which issues + exactly one embed request regardless of ``len(queries)``. Args: queries: Iterable of natural-language query strings. Order is preserved in the returned list. - top_k: Per-call override of ``self.top_k``. Passed through - to :meth:`queries` as a local value, so ``self.top_k`` is - never written. Concurrent callers sharing a single - :class:`Retriever` instance therefore remain thread-safe. + top_k: Per-call override of ``self.top_k``; passed as a local + value so the instance attribute is never mutated. embedder: Per-call embedder override. lancedb_uri: Per-call LanceDB URI override. lancedb_table: Per-call LanceDB table override. @@ -888,17 +872,13 @@ def run( graph = retrieval_op for step in self._steps: graph = graph >> step - # ``Graph.execute`` returns one entry per leaf node; a linear - # live-RAG pipeline has exactly one leaf. + # Linear live-RAG pipelines have exactly one leaf. leaves = graph.execute(df) if len(leaves) != 1: raise RuntimeError(f"Unexpected pipeline fan-out: got {len(leaves)} leaf outputs") out = leaves[0] - # Surface generation failure rate when the pipeline ran generation. - # ``QAGenerationOperator`` writes ``gen_error`` per row (non-null on - # failure); downstream aggregators read ``df.attrs`` without having - # to scan rows or guard on column presence themselves. + # Expose the generation failure rate on ``df.attrs`` for downstream aggregators. 
if "gen_error" in out.columns and len(out) > 0: out.attrs["generation_failure_rate"] = float(out["gen_error"].notna().mean()) diff --git a/nemo_retriever/tests/test_llm_params.py b/nemo_retriever/tests/test_llm_params.py index 7a1d3bf83..ec9a740a6 100644 --- a/nemo_retriever/tests/test_llm_params.py +++ b/nemo_retriever/tests/test_llm_params.py @@ -101,13 +101,21 @@ def test_structured_construction(self): assert client.sampling is sampling assert client.model == "openai/gpt-4o-mini" - def test_default_sampling_is_llminferenceparams_defaults(self): + def test_default_sampling_matches_from_kwargs_for_rag_determinism(self): + """``LiteLLMClient`` is a RAG-eval client and must default to + deterministic sampling regardless of which constructor path the + caller picks. The structured constructor therefore overrides + ``LLMInferenceParams``'s general-purpose ``temperature=1.0`` with + ``0.0`` so it agrees with :meth:`LiteLLMClient.from_kwargs`. + ``top_p`` / ``max_tokens`` still come from ``LLMInferenceParams`` + (they already match what ``from_kwargs`` builds). + """ from nemo_retriever.llm.clients import LiteLLMClient from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams client = LiteLLMClient(transport=LLMRemoteClientParams(model="m")) assert isinstance(client.sampling, LLMInferenceParams) - assert client.sampling.temperature == 1.0 + assert client.sampling.temperature == 0.0 assert client.sampling.top_p is None assert client.sampling.max_tokens == 1024 @@ -393,3 +401,43 @@ def test_default_model_is_a_non_empty_string(self): assert isinstance(LiteLLMClient._DEFAULT_MODEL, str) assert LiteLLMClient._DEFAULT_MODEL + + +class TestLiteLLMDefaultSamplingAlignment: + """Both constructor paths must default to the same deterministic sampling. + + Regression test for the Greptile P1 finding that + ``LiteLLMClient(transport=...)`` with ``sampling=None`` silently + fell through to ``LLMInferenceParams()`` (``temperature=1.0``) while + ``LiteLLMClient.from_kwargs(...)`` explicitly defaulted to + ``temperature=0.0``. For RAG-eval reproducibility the two paths + must converge on the same default. 
+ """ + + def test_structured_constructor_defaults_to_zero_temperature(self): + from nemo_retriever.llm.clients import LiteLLMClient + from nemo_retriever.params import LLMRemoteClientParams + + client = LiteLLMClient(transport=LLMRemoteClientParams(model="m")) + assert client.sampling.temperature == 0.0 + + def test_structured_and_flat_paths_agree_on_defaults(self): + from nemo_retriever.llm.clients import LiteLLMClient + from nemo_retriever.params import LLMRemoteClientParams + + structured = LiteLLMClient(transport=LLMRemoteClientParams(model="m")) + flat = LiteLLMClient.from_kwargs(model="m") + assert structured.sampling.temperature == flat.sampling.temperature + assert structured.sampling.max_tokens == flat.sampling.max_tokens + assert structured.sampling.top_p == flat.sampling.top_p + + def test_explicit_sampling_is_not_overridden(self): + """Passing an explicit ``LLMInferenceParams`` must win over the default.""" + from nemo_retriever.llm.clients import LiteLLMClient + from nemo_retriever.params import LLMInferenceParams, LLMRemoteClientParams + + client = LiteLLMClient( + transport=LLMRemoteClientParams(model="m"), + sampling=LLMInferenceParams(temperature=0.7), + ) + assert client.sampling.temperature == 0.7 From c803a4b70c24f90a98898da326705d15888bc2c4 Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:23:39 +0000 Subject: [PATCH 07/10] resolve greptile issues + add tests cases + fix existing ones --- .../src/nemo_retriever/evaluation/config.py | 24 ++- .../nemo_retriever/evaluation/retrievers.py | 42 ++-- .../src/nemo_retriever/io/markdown.py | 20 +- .../src/nemo_retriever/llm/clients/litellm.py | 9 +- .../tests/test_evaluation_config.py | 109 ++++++++++ nemo_retriever/tests/test_file_retriever.py | 198 ++++++++++++++++++ nemo_retriever/tests/test_io_markdown.py | 102 ++++++++- nemo_retriever/tests/test_llm_params.py | 16 +- .../tests/test_retriever_queries.py | 16 +- 9 files changed, 495 insertions(+), 41 deletions(-) create mode 100644 nemo_retriever/tests/test_evaluation_config.py create mode 100644 nemo_retriever/tests/test_file_retriever.py diff --git a/nemo_retriever/src/nemo_retriever/evaluation/config.py b/nemo_retriever/src/nemo_retriever/evaluation/config.py index 8fd33ea2b..dbd3fd7cc 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/config.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/config.py @@ -118,6 +118,16 @@ def _normalize_config(config: dict) -> dict: The ``evaluations`` list is always present after normalisation (synthesised from ``generators`` + ``judge`` when using legacy format). + + Raises + ------ + ValueError + If ``evaluations`` specifies multiple distinct judges. + :func:`build_eval_chain` and :func:`build_eval_pipeline` support + only a single judge per invocation; use + :func:`~nemo_retriever.evaluation.runner.run_eval_sweep` for + heterogeneous sweeps -- it iterates ``evaluations`` and selects + the correct judge per combo. """ if "models" in config and "evaluations" in config: models = config["models"] @@ -140,13 +150,13 @@ def _normalize_config(config: dict) -> dict: first_judge_key = evals[0]["judge"] distinct_judges = {e["judge"] for e in evals} if len(distinct_judges) > 1: - logger.warning( - "Config has %d distinct judges %s; legacy 'judge' key uses " - "only the first (%r). 
Use --config sweep or build per-eval " - "clients for heterogeneous judges.", - len(distinct_judges), - sorted(distinct_judges), - first_judge_key, + raise ValueError( + f"Config has {len(distinct_judges)} distinct judges " + f"{sorted(distinct_judges)}. " + "build_eval_chain() and build_eval_pipeline() support only a " + "single judge per invocation. Use run_eval_sweep() for " + "heterogeneous judges -- it iterates the evaluations list " + "and uses the correct judge per-combo." ) config.setdefault("judge", models[first_judge_key]) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py index 74884e666..3b387fd07 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/retrievers.py @@ -60,8 +60,6 @@ def __init__(self, file_path: str): if not os.path.exists(file_path): raise FileNotFoundError(f"FileRetriever: retrieval results file not found: {file_path}") - self.file_path = file_path - with open(file_path) as f: data = json.load(f) @@ -79,16 +77,40 @@ def __init__(self, file_path: str): 'Expected: {"queries": {"query": {"chunks": ["..."]}}}' ) + self._initialize_index(raw_index, source=file_path) + + def _initialize_index(self, raw_index: dict[str, dict], *, source: str) -> None: + """Populate instance state from an already-validated queries mapping. + + Single source of truth for all :class:`FileRetriever` instance + fields used by :meth:`retrieve` and :meth:`check_coverage`. + Called by both :meth:`__init__` (file-based) and + :meth:`_from_dict` (in-memory, used by :meth:`from_lancedb`) so + that new instance fields only need to be added in one place and + can never diverge between the two construction paths. + + Parameters + ---------- + raw_index : dict[str, dict] + ``{query_text: {"chunks": [...], "metadata": [...]}}`` -- + the same shape both entry points produce. Must already be + non-empty and contain a ``chunks`` list; validation is the + caller's responsibility so error messages can reference the + originating source (file path vs. in-memory dict). + source : str + Human-readable origin label stored on ``self.file_path`` + (e.g. a filesystem path or ``""``). + """ + self.file_path = source self._norm_index: dict[str, dict] = {} self._raw_keys: dict[str, str] = {} + self._miss_count = 0 + self._miss_lock = threading.Lock() for raw_key, value in raw_index.items(): norm = _normalize_query(raw_key) self._norm_index[norm] = value self._raw_keys[norm] = raw_key - self._miss_count = 0 - self._miss_lock = threading.Lock() - @classmethod def _from_dict(cls, queries: dict[str, dict]) -> "FileRetriever": """Build a FileRetriever from an in-memory queries dict. 
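The `_initialize_index` consolidation above is an instance of a common alternate-constructor pattern: one private initializer owns every instance field, and constructors that bypass `__init__` (via `object.__new__`) funnel through the same helper so the two paths cannot drift apart. A generic sketch of the pattern -- illustrative only, with made-up names -- looks like this:

```python
# Illustrative sketch of the shared-initializer pattern; not NeMo Retriever code.
class Index:
    def __init__(self, path: str):
        with open(path) as f:
            entries = {line.strip(): {"hits": []} for line in f if line.strip()}
        self._initialize(entries, source=path)

    @classmethod
    def from_dict(cls, entries: dict[str, dict]) -> "Index":
        # Skip __init__ (no file on disk), but reuse the same initializer so
        # any field added later is automatically set on both construction paths.
        instance = object.__new__(cls)
        instance._initialize(entries, source="in-memory")
        return instance

    def _initialize(self, entries: dict[str, dict], *, source: str) -> None:
        # Single source of truth for every field the rest of the class reads.
        self.source = source
        self._entries = dict(entries)
        self._misses = 0
```

The parity tests added later in this series (`test_file_retriever.py`) pin exactly this property for `FileRetriever`: both entry points must produce instances with identical state.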
@@ -113,15 +135,7 @@ def _from_dict(cls, queries: dict[str, dict]) -> "FileRetriever": ) instance = object.__new__(cls) - instance.file_path = "" - instance._norm_index = {} - instance._raw_keys = {} - instance._miss_count = 0 - instance._miss_lock = threading.Lock() - for raw_key, value in queries.items(): - norm = _normalize_query(raw_key) - instance._norm_index[norm] = value - instance._raw_keys[norm] = raw_key + instance._initialize_index(queries, source="") return instance @classmethod diff --git a/nemo_retriever/src/nemo_retriever/io/markdown.py b/nemo_retriever/src/nemo_retriever/io/markdown.py index fe828bb35..b2201cd3c 100644 --- a/nemo_retriever/src/nemo_retriever/io/markdown.py +++ b/nemo_retriever/src/nemo_retriever/io/markdown.py @@ -293,6 +293,13 @@ def _label_for_subtype(subtype: Any, *, fallback: str) -> str: "source_id", "page_number", "text", + # Top-level ``content`` is a fallback text source used by + # _collect_page_record / _collect_primitive_record when neither + # ``text`` nor ``metadata.content`` is populated. It must be + # kept so both the Parquet path (_read_parquet_for_markdown) + # and the in-memory ``dataframe=`` path in build_page_index + # render the same output as the renderer unit tests expect. + "content", "document_type", "_content_type", "metadata", @@ -361,7 +368,18 @@ def build_page_index( raise ValueError("Provide exactly one of parquet_dir or dataframe.") if dataframe is not None: - df = dataframe + # Mirror the column pruning that _read_parquet_for_markdown applies on + # the parquet_dir= path. Caller-supplied DataFrames often still carry + # huge columns (page_image base64 blobs, embedding vectors) that + # row.to_dict() would otherwise materialise for every record, causing + # the same multi-GB memory spikes _read_parquet_for_markdown was built + # to avoid. df[relevant] returns a column-subset view, not a copy -- + # the caller's DataFrame is never mutated. + relevant = [c for c in _MARKDOWN_PARQUET_COLUMNS if c in dataframe.columns] + if relevant and len(relevant) < len(dataframe.columns): + df = dataframe[relevant] + else: + df = dataframe else: parquet_path = Path(parquet_dir) if not parquet_path.is_dir(): diff --git a/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py index 599e8fe14..78895c298 100644 --- a/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py +++ b/nemo_retriever/src/nemo_retriever/llm/clients/litellm.py @@ -78,9 +78,12 @@ def __init__( sampling: Optional[LLMInferenceParams] = None, ): self.transport = transport - # Default to ``temperature=0.0`` so the structured constructor matches ``from_kwargs`` - # and keeps RAG-eval runs deterministic. - self.sampling = sampling if sampling is not None else LLMInferenceParams(temperature=0.0) + # Default to ``temperature=0.0, max_tokens=4096`` so the structured + # constructor matches ``from_kwargs`` and keeps RAG-eval runs + # deterministic. ``LLMInferenceParams`` itself defaults to + # ``max_tokens=1024`` for captioning/summarization workloads; RAG + # answers routinely exceed that, so the client overrides it. 
+ self.sampling = sampling if sampling is not None else LLMInferenceParams(temperature=0.0, max_tokens=4096) @property def model(self) -> str: diff --git a/nemo_retriever/tests/test_evaluation_config.py b/nemo_retriever/tests/test_evaluation_config.py new file mode 100644 index 000000000..1fb25fce6 --- /dev/null +++ b/nemo_retriever/tests/test_evaluation_config.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for :mod:`nemo_retriever.evaluation.config`. + +Focus: the fail-fast contract in :func:`_normalize_config` that guards +``build_eval_chain`` / ``build_eval_pipeline`` from silently collapsing +heterogeneous-judge configs to a single judge. +""" + +from __future__ import annotations + +import pytest + +from nemo_retriever.evaluation.config import _normalize_config + + +def _make_multi_judge_config() -> dict: + """New ``models`` + ``evaluations`` schema with two distinct judges.""" + return { + "models": { + "gen-a": {"model": "provider/gen-a", "api_key": "k"}, + "gen-b": {"model": "provider/gen-b", "api_key": "k"}, + "judge-x": {"model": "provider/judge-x", "api_key": "k"}, + "judge-y": {"model": "provider/judge-y", "api_key": "k"}, + }, + "evaluations": [ + {"generator": "gen-a", "judge": "judge-x"}, + {"generator": "gen-b", "judge": "judge-y"}, + ], + } + + +def _make_single_judge_new_schema_config() -> dict: + """New schema with one judge shared across multiple generators.""" + return { + "models": { + "gen-a": {"model": "provider/gen-a", "api_key": "k"}, + "gen-b": {"model": "provider/gen-b", "api_key": "k"}, + "judge-x": {"model": "provider/judge-x", "api_key": "k"}, + }, + "evaluations": [ + {"generator": "gen-a", "judge": "judge-x"}, + {"generator": "gen-b", "judge": "judge-x"}, + ], + } + + +def _make_legacy_config() -> dict: + """Legacy ``generators`` + ``judge`` schema.""" + return { + "generators": [ + {"name": "gen-a", "model": "provider/gen-a", "api_key": "k"}, + {"name": "gen-b", "model": "provider/gen-b", "api_key": "k"}, + ], + "judge": {"name": "judge-x", "model": "provider/judge-x", "api_key": "k"}, + } + + +def test_normalize_config_multi_judge_raises() -> None: + """Multi-judge configs must fail fast instead of silently collapsing. + + Previously ``_normalize_config`` logged a warning and kept the first + judge, which meant ``build_eval_chain`` / ``build_eval_pipeline`` + scored every generator against that single judge without any error. 
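+
+    The assertions below pin three properties of the error message: it
+    must name ``run_eval_sweep`` as the remediation, list both conflicting
+    judge keys, and report how many distinct judges collided.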
+ """ + config = _make_multi_judge_config() + + with pytest.raises(ValueError) as exc_info: + _normalize_config(config) + + message = str(exc_info.value) + assert "run_eval_sweep" in message, "error must point users at the correct API for heterogeneous judges" + assert "judge-x" in message and "judge-y" in message, "error must list the distinct judges that conflict" + assert "2 distinct judges" in message, "error must report how many judges collided" + + +def test_normalize_config_single_judge_new_schema_passes() -> None: + """Single-judge configs in the new schema must normalise cleanly.""" + config = _make_single_judge_new_schema_config() + + normalized = _normalize_config(config) + + assert "generators" in normalized + assert "judge" in normalized + assert normalized["judge"]["model"] == "provider/judge-x" + gen_names = {g["name"] for g in normalized["generators"]} + assert gen_names == {"gen-a", "gen-b"} + + +def test_normalize_config_legacy_schema_passes() -> None: + """Legacy ``generators`` + ``judge`` configs must still normalise. + + The legacy schema has a scalar ``judge`` by construction, so it can + never trigger the multi-judge fail-fast path. + """ + config = _make_legacy_config() + + normalized = _normalize_config(config) + + assert "models" in normalized + assert "evaluations" in normalized + assert "gen-a" in normalized["models"] + assert "gen-b" in normalized["models"] + assert "judge-x" in normalized["models"] + + eval_judges = {e["judge"] for e in normalized["evaluations"]} + assert eval_judges == {"judge-x"} diff --git a/nemo_retriever/tests/test_file_retriever.py b/nemo_retriever/tests/test_file_retriever.py new file mode 100644 index 000000000..1e056e2df --- /dev/null +++ b/nemo_retriever/tests/test_file_retriever.py @@ -0,0 +1,198 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for :class:`nemo_retriever.evaluation.retrievers.FileRetriever`. + +These tests pin the contract for both entry points -- the file-based +``__init__`` and the in-memory ``_from_dict`` -- and assert that both +produce instances with identical state. That parity invariant is the +structural guard against the two construction paths silently diverging +when new instance fields are added in future. 
+""" + +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import pytest + +from nemo_retriever.evaluation.retrievers import FileRetriever +from nemo_retriever.llm.types import RetrievalResult + +_SAMPLE_QUERIES: dict[str, dict] = { + "What is the range of the 767?": { + "chunks": ["The 767 has a range of ~6,000 nmi.", "Variants differ."], + "metadata": [{"source": "spec.pdf"}, {"source": "variants.pdf"}], + }, + "How many seats does the 747 have?": { + "chunks": ["Up to 524 passengers in 3-class config."], + "metadata": [{"source": "747_brochure.pdf"}], + }, +} + + +def _write_retrieval_json(tmp_path: Path, queries: dict[str, dict]) -> Path: + path = tmp_path / "retrieval.json" + path.write_text(json.dumps({"queries": queries}), encoding="utf-8") + return path + + +def test_init_roundtrip(tmp_path: Path) -> None: + """Loading from JSON -> retrieve() returns the stored chunks/metadata.""" + path = _write_retrieval_json(tmp_path, _SAMPLE_QUERIES) + + retriever = FileRetriever(file_path=str(path)) + result = retriever.retrieve("What is the range of the 767?", top_k=2) + + assert isinstance(result, RetrievalResult) + assert result.chunks == _SAMPLE_QUERIES["What is the range of the 767?"]["chunks"] + assert result.metadata == _SAMPLE_QUERIES["What is the range of the 767?"]["metadata"] + assert retriever.file_path == str(path) + + +def test_init_empty_raises(tmp_path: Path) -> None: + """Empty ``queries`` dict raises ValueError with the file path in the message.""" + path = _write_retrieval_json(tmp_path, {}) + + with pytest.raises(ValueError, match="no 'queries' key found") as exc_info: + FileRetriever(file_path=str(path)) + + assert str(path) in str(exc_info.value), "error must reference the offending file path" + + +def test_init_missing_chunks_raises(tmp_path: Path) -> None: + """An entry without a ``chunks`` list raises ValueError.""" + path = _write_retrieval_json(tmp_path, {"a query": {"metadata": [{"source": "x"}]}}) + + with pytest.raises(ValueError, match="missing a 'chunks' list"): + FileRetriever(file_path=str(path)) + + +def test_init_missing_file_raises(tmp_path: Path) -> None: + """A non-existent file raises FileNotFoundError, not a cryptic IOError.""" + path = tmp_path / "does_not_exist.json" + + with pytest.raises(FileNotFoundError, match="retrieval results file not found"): + FileRetriever(file_path=str(path)) + + +def test_from_dict_roundtrip() -> None: + """Loading from in-memory dict -> retrieve() returns the stored chunks.""" + retriever = FileRetriever._from_dict(_SAMPLE_QUERIES) + result = retriever.retrieve("How many seats does the 747 have?", top_k=5) + + assert result.chunks == _SAMPLE_QUERIES["How many seats does the 747 have?"]["chunks"] + assert retriever.file_path == "" + + +def test_from_dict_normalizes_keys() -> None: + """Whitespace and case variations in the query still match the stored entry. + + Locks the contract with :func:`_normalize_query` -- if either + construction path skips normalization the lookup would miss. + """ + retriever = FileRetriever._from_dict(_SAMPLE_QUERIES) + + variants = [ + "what is the range of the 767?", + "What is the range of the 767? ", + " What is the range of the 767? 
", + ] + for variant in variants: + result = retriever.retrieve(variant, top_k=2) + assert result.chunks, f"normalized lookup failed for variant {variant!r}" + + +def test_from_dict_empty_raises() -> None: + """Empty dict raises ValueError whose message identifies the in-memory path.""" + with pytest.raises(ValueError, match="_from_dict: queries dict is empty"): + FileRetriever._from_dict({}) + + +def test_from_dict_missing_chunks_raises() -> None: + """Entry without a ``chunks`` list raises a _from_dict-tagged ValueError.""" + with pytest.raises(ValueError, match="_from_dict: first entry is missing"): + FileRetriever._from_dict({"a query": {"metadata": []}}) + + +def test_init_and_from_dict_have_identical_state(tmp_path: Path) -> None: + """Structural invariant: both construction paths produce the same instance shape. + + This is the guard against divergence that jperez flagged. If a new + instance field is ever added to one entry point but not the other, + this test fails immediately -- no silent runtime bug. + """ + path = _write_retrieval_json(tmp_path, _SAMPLE_QUERIES) + + via_init = FileRetriever(file_path=str(path)) + via_from_dict = FileRetriever._from_dict(_SAMPLE_QUERIES) + + # Both instances must expose the same set of public + private fields. + assert set(vars(via_init).keys()) == set(vars(via_from_dict).keys()) + + # And every field of the same type: catches e.g. one path forgetting + # to initialise the lock or the miss counter. + for field_name in vars(via_init): + init_attr = getattr(via_init, field_name) + from_dict_attr = getattr(via_from_dict, field_name) + assert type(init_attr) is type(from_dict_attr), ( + f"field {field_name!r} has mismatched types: " + f"{type(init_attr).__name__} via __init__ vs " + f"{type(from_dict_attr).__name__} via _from_dict" + ) + + # The normalized index must contain the same keys and chunk payloads + # regardless of entry point. + assert via_init._norm_index.keys() == via_from_dict._norm_index.keys() + for norm_key in via_init._norm_index: + assert via_init._norm_index[norm_key] == via_from_dict._norm_index[norm_key] + + +def test_from_lancedb_save_path_sets_file_path(tmp_path: Path) -> None: + """When ``save_path`` is provided, ``file_path`` reflects the saved path. + + Mocks :func:`query_lancedb` + :func:`write_retrieval_json` so the + test does not depend on a live LanceDB directory. 
+ """ + save_path = tmp_path / "saved_retrieval.json" + fake_meta = {"lancedb_uri": "mock"} + + with ( + patch( + "nemo_retriever.export.query_lancedb", + return_value=(_SAMPLE_QUERIES, fake_meta), + ), + patch("nemo_retriever.export.write_retrieval_json") as mock_write, + ): + retriever = FileRetriever.from_lancedb( + qa_pairs=[{"query": "What is the range of the 767?"}], + lancedb_uri="mock", + save_path=str(save_path), + ) + + assert retriever.file_path == str(save_path) + mock_write.assert_called_once() + + +def test_from_lancedb_no_save_path_keeps_memory_label() -> None: + """Without ``save_path`` the instance reports the in-memory origin.""" + fake_meta = {"lancedb_uri": "mock"} + + with ( + patch( + "nemo_retriever.export.query_lancedb", + return_value=(_SAMPLE_QUERIES, fake_meta), + ), + patch("nemo_retriever.export.write_retrieval_json") as mock_write, + ): + retriever = FileRetriever.from_lancedb( + qa_pairs=[{"query": "What is the range of the 767?"}], + lancedb_uri="mock", + save_path=None, + ) + + assert retriever.file_path == "" + mock_write.assert_not_called() diff --git a/nemo_retriever/tests/test_io_markdown.py b/nemo_retriever/tests/test_io_markdown.py index 837f6d8e0..e5f22605e 100644 --- a/nemo_retriever/tests/test_io_markdown.py +++ b/nemo_retriever/tests/test_io_markdown.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from nemo_retriever.io import to_markdown, to_markdown_by_page +from nemo_retriever.io import build_page_index, to_markdown, to_markdown_by_page class _LazyRows: @@ -130,3 +130,103 @@ def test_to_markdown_rejects_multi_document_results() -> None: with pytest.raises(ValueError, match="single document result"): to_markdown([doc_a, doc_b]) + + +def test_build_page_index_prunes_irrelevant_columns_on_dataframe_path() -> None: + """Regression test: ``build_page_index(dataframe=)`` must not materialise + huge columns that the markdown renderer never reads. + + Before the fix, the ``dataframe=`` branch passed the user's DataFrame + straight through to ``df.iterrows()`` + ``row.to_dict()``, which + materialised every column -- including ``page_image`` base64 blobs and + embedding vectors -- for every record. That produced the same multi-GB + memory spikes ``_read_parquet_for_markdown`` was explicitly built to + avoid on the Parquet path. The fix mirrors the Parquet-path column + pruning via ``_MARKDOWN_PARQUET_COLUMNS``. + + This test verifies three guarantees: + 1. Rendering still produces the same markdown output. + 2. A huge extraneous column does not propagate into the rendered + records (catches the bug by construction). + 3. The caller's DataFrame is not mutated in place. 
+ """ + large_blob = "x" * 100_000 + df = pd.DataFrame( + [ + { + "path": "/tmp/doc.pdf", + "page_number": 1, + "text": "First page text", + "page_image": large_blob, + "embedding": [0.0] * 1024, + }, + { + "path": "/tmp/doc.pdf", + "page_number": 2, + "text": "Second page text", + "page_image": large_blob, + "embedding": [0.0] * 1024, + }, + ] + ) + original_columns = set(df.columns) + + index, failures = build_page_index(dataframe=df) + + assert not failures + assert "/tmp/doc.pdf" in index + rendered = index["/tmp/doc.pdf"] + assert "1" in rendered and "2" in rendered + assert "First page text" in rendered["1"] + assert "Second page text" in rendered["2"] + + for page_md in rendered.values(): + assert large_blob not in page_md, "huge column leaked into rendered markdown" + + assert set(df.columns) == original_columns, "caller's DataFrame must not be mutated" + assert "page_image" in df.columns + assert "embedding" in df.columns + + +def test_build_page_index_no_op_when_all_columns_are_allow_listed() -> None: + """When the caller already supplies a pruned DataFrame, the filter is + a no-op: ``df`` is identical (same object, same columns).""" + df = pd.DataFrame( + [ + { + "path": "/tmp/doc.pdf", + "page_number": 1, + "text": "Only essentials", + } + ] + ) + + index, failures = build_page_index(dataframe=df) + + assert not failures + assert "/tmp/doc.pdf" in index + assert "Only essentials" in index["/tmp/doc.pdf"]["1"] + + +def test_build_page_index_preserves_content_fallback_column() -> None: + """Guards against regressing the ``content`` fallback path. + + ``_collect_page_record`` reads ``record.get("content")`` as a tertiary + fallback when ``record.get("text")`` is absent. The allow-list must + include ``content`` so rows that carry only that column still render. + """ + df = pd.DataFrame( + [ + { + "path": "/tmp/content_only.pdf", + "page_number": 1, + "content": "Fallback body text", + "page_image": "x" * 100_000, + } + ] + ) + + index, failures = build_page_index(dataframe=df) + + assert not failures + assert "Fallback body text" in index["/tmp/content_only.pdf"]["1"] diff --git a/nemo_retriever/tests/test_llm_params.py b/nemo_retriever/tests/test_llm_params.py index ec9a740a6..8cf65b955 100644 --- a/nemo_retriever/tests/test_llm_params.py +++ b/nemo_retriever/tests/test_llm_params.py @@ -102,13 +102,13 @@ def test_structured_construction(self): assert client.model == "openai/gpt-4o-mini" def test_default_sampling_matches_from_kwargs_for_rag_determinism(self): - """``LiteLLMClient`` is a RAG-eval client and must default to - deterministic sampling regardless of which constructor path the - caller picks. The structured constructor therefore overrides - ``LLMInferenceParams``'s general-purpose ``temperature=1.0`` with - ``0.0`` so it agrees with :meth:`LiteLLMClient.from_kwargs`. - ``top_p`` / ``max_tokens`` still come from ``LLMInferenceParams`` - (they already match what ``from_kwargs`` builds). + """``LiteLLMClient`` is a RAG-eval client and must default to the + same deterministic sampling regardless of which constructor path + the caller picks. The structured constructor therefore overrides + ``LLMInferenceParams``'s general-purpose defaults + (``temperature=1.0``, ``max_tokens=1024``) with the RAG-tuned + ``temperature=0.0`` / ``max_tokens=4096`` so it agrees with + :meth:`LiteLLMClient.from_kwargs`. 
""" from nemo_retriever.llm.clients import LiteLLMClient from nemo_retriever.params.models import LLMInferenceParams, LLMRemoteClientParams @@ -117,7 +117,7 @@ def test_default_sampling_matches_from_kwargs_for_rag_determinism(self): assert isinstance(client.sampling, LLMInferenceParams) assert client.sampling.temperature == 0.0 assert client.sampling.top_p is None - assert client.sampling.max_tokens == 1024 + assert client.sampling.max_tokens == 4096 def test_from_kwargs_matches_explicit(self): from nemo_retriever.llm.clients import LiteLLMClient diff --git a/nemo_retriever/tests/test_retriever_queries.py b/nemo_retriever/tests/test_retriever_queries.py index 2d5f9301e..eec292bf2 100644 --- a/nemo_retriever/tests/test_retriever_queries.py +++ b/nemo_retriever/tests/test_retriever_queries.py @@ -15,7 +15,6 @@ import pytest - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -190,6 +189,7 @@ def test_query_delegates_to_queries_and_returns_first_element(self): mock_queries.assert_called_once_with( ["find something"], + top_k=None, embedder=None, lancedb_uri=None, lancedb_table=None, @@ -201,7 +201,7 @@ def test_query_passes_through_overrides(self): with patch.object(r, "queries", return_value=[[]]) as mock_queries: r.query("q", embedder="e", lancedb_uri="u", lancedb_table="t") - mock_queries.assert_called_once_with(["q"], embedder="e", lancedb_uri="u", lancedb_table="t") + mock_queries.assert_called_once_with(["q"], top_k=None, embedder="e", lancedb_uri="u", lancedb_table="t") # --------------------------------------------------------------------------- @@ -234,7 +234,7 @@ def test_rerank_results_called_when_reranker_set(self): ): r.queries(["q"]) - mock_rerank.assert_called_once_with(["q"], fake_results) + mock_rerank.assert_called_once_with(["q"], fake_results, top_k=r.top_k) def test_rerank_not_called_when_reranker_is_none(self): r = _make_retriever(reranker=None) @@ -275,7 +275,7 @@ def test_rerank_results_uses_endpoint_not_local_model(self): } with patch("requests.post", return_value=mock_resp) as mock_post: - out = r._rerank_results(["q"], [fake_hits]) + out = r._rerank_results(["q"], [fake_hits], top_k=r.top_k) mock_post.assert_called() # Results should be sorted descending @@ -297,7 +297,7 @@ def test_rerank_results_with_local_model(self): fake_model.score.return_value = [0.1, 0.9, 0.5, 0.3] with patch.object(r, "_get_reranker_model", return_value=fake_model): - out = r._rerank_results(["q"], [hits]) + out = r._rerank_results(["q"], [hits], top_k=r.top_k) scores = [h["_rerank_score"] for h in out[0]] assert scores == sorted(scores, reverse=True) @@ -310,7 +310,7 @@ def test_rerank_results_respects_top_k(self): fake_model.score.return_value = [0.1, 0.9, 0.5, 0.3] with patch.object(r, "_get_reranker_model", return_value=fake_model): - out = r._rerank_results(["q"], [hits]) + out = r._rerank_results(["q"], [hits], top_k=r.top_k) assert len(out[0]) == 2 @@ -322,7 +322,7 @@ def test_rerank_results_multiple_queries(self): fake_model.score.side_effect = [[0.2, 0.8], [0.6, 0.4]] with patch.object(r, "_get_reranker_model", return_value=fake_model): - out = r._rerank_results(["q1", "q2"], [hits_a, hits_b]) + out = r._rerank_results(["q1", "q2"], [hits_a, hits_b], top_k=r.top_k) assert len(out) == 2 # Each per-query list should be sorted descending @@ -424,6 +424,7 @@ def test_extra_keys_stripped_from_dense_results(self): lancedb_table="t", query_vectors=[_DUMMY_VECTOR], 
query_texts=["q"], + top_k=r.top_k, ) hit = results[0][0] @@ -471,6 +472,7 @@ def test_extra_keys_stripped_from_hybrid_results(self): lancedb_table="t", query_vectors=[_DUMMY_VECTOR], query_texts=["q"], + top_k=r.top_k, ) hit = results[0][0] From 8137d3b6c18761d008b0fa934b00d30208dc5a29 Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:31:07 +0000 Subject: [PATCH 08/10] fix smoke testing + issues --- nemo_retriever/pyproject.toml | 2 +- .../nemo_retriever/evaluation/orchestrator.py | 2 +- .../src/nemo_retriever/io/markdown.py | 11 +- .../tests/test_evaluation_orchestrator.py | 149 ++++++++++++++++++ ...iever.py => test_evaluation_retrievers.py} | 0 5 files changed, 155 insertions(+), 9 deletions(-) create mode 100644 nemo_retriever/tests/test_evaluation_orchestrator.py rename nemo_retriever/tests/{test_file_retriever.py => test_evaluation_retrievers.py} (100%) diff --git a/nemo_retriever/pyproject.toml b/nemo_retriever/pyproject.toml index 9408502cf..c7896ebb1 100644 --- a/nemo_retriever/pyproject.toml +++ b/nemo_retriever/pyproject.toml @@ -126,7 +126,7 @@ dev = [ # ── Convenience: full install ───────────────────────────────────────────────── all = [ - "nemo_retriever[local,multimedia,stores,benchmarks]", + "nemo_retriever[local,multimedia,stores,benchmarks,llm]", ] [project.scripts] diff --git a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py index faa5f5be5..bfa6b8f01 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/orchestrator.py @@ -265,7 +265,7 @@ def _prepare_dataframe(self, qa_pairs: list[dict]) -> pd.DataFrame: def _retrieve(idx: int, pair: dict) -> tuple[int, dict]: query = pair["query"] - reference = pair.get("reference_answer") or pair["answer"] + reference = pair.get("reference_answer") or pair.get("answer", "") retrieval = self.retriever.retrieve(query, self.top_k) return idx, { "query": query, diff --git a/nemo_retriever/src/nemo_retriever/io/markdown.py b/nemo_retriever/src/nemo_retriever/io/markdown.py index b2201cd3c..fac99032d 100644 --- a/nemo_retriever/src/nemo_retriever/io/markdown.py +++ b/nemo_retriever/src/nemo_retriever/io/markdown.py @@ -368,13 +368,10 @@ def build_page_index( raise ValueError("Provide exactly one of parquet_dir or dataframe.") if dataframe is not None: - # Mirror the column pruning that _read_parquet_for_markdown applies on - # the parquet_dir= path. Caller-supplied DataFrames often still carry - # huge columns (page_image base64 blobs, embedding vectors) that - # row.to_dict() would otherwise materialise for every record, causing - # the same multi-GB memory spikes _read_parquet_for_markdown was built - # to avoid. df[relevant] returns a column-subset view, not a copy -- - # the caller's DataFrame is never mutated. + # Prune to the same allow-list the parquet_dir= path uses so wide + # columns like page_image base64 blobs or embedding vectors never + # reach row.to_dict(). df[relevant] is a column-subset view, not + # a copy -- the caller's DataFrame is not mutated. 
relevant = [c for c in _MARKDOWN_PARQUET_COLUMNS if c in dataframe.columns] if relevant and len(relevant) < len(dataframe.columns): df = dataframe[relevant] diff --git a/nemo_retriever/tests/test_evaluation_orchestrator.py b/nemo_retriever/tests/test_evaluation_orchestrator.py new file mode 100644 index 000000000..8415cd148 --- /dev/null +++ b/nemo_retriever/tests/test_evaluation_orchestrator.py @@ -0,0 +1,149 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Regression tests for ``QAEvalPipeline._prepare_dataframe`` reference-answer +fallback. + +Before the fix in this PR, line 268 of ``orchestrator.py`` used:: + + reference = pair.get("reference_answer") or pair["answer"] + +which raised ``KeyError('answer')`` in two situations: + +1. Neither ``reference_answer`` nor ``answer`` key is present in the qa pair. +2. ``reference_answer`` is the empty string ``""`` (falsy), and ``answer`` + is missing. + +Both cases were silently caught by the surrounding ``except Exception`` +handler and surfaced to users as misleading ``"Retrieval for query [..] +failed"`` log lines, even though retrieval was never attempted. + +The fix mirrors the identical fallback already used on line 294:: + + reference = pair.get("reference_answer") or pair.get("answer", "") + +so both cases now yield an empty-string reference and retrieval proceeds +normally. The row still participates in downstream scoring; the empty +reference simply produces an empty ``answer_in_context``/``token_f1`` +value, which is the correct behaviour for a missing ground-truth pair. + +These tests additionally guard against the two lines drifting apart again +by asserting identical behaviour between the happy-path and exception-path +on the same malformed input. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from nemo_retriever.evaluation.orchestrator import QAEvalPipeline +from nemo_retriever.llm.types import RetrievalResult + + +def _make_pipeline(*, retrieve_side_effect: Any = None) -> QAEvalPipeline: + """Build a QAEvalPipeline with mocked retriever/llm/judge. + + The retriever returns a stable empty RetrievalResult unless a custom + ``retrieve_side_effect`` is provided (e.g. to simulate a real retrieval + failure for the exception-path test). + """ + retriever = MagicMock() + if retrieve_side_effect is None: + retriever.retrieve.return_value = RetrievalResult(chunks=[], metadata=[]) + else: + retriever.retrieve.side_effect = retrieve_side_effect + + llm = MagicMock() + judge = MagicMock() + + return QAEvalPipeline( + retriever=retriever, + llm_clients={"m": llm}, + judge=judge, + top_k=3, + max_workers=1, + ) + + +class TestReferenceAnswerFallback: + """Line 268 must mirror line 294's fallback semantics exactly.""" + + def test_missing_both_keys_does_not_raise(self) -> None: + """No ``KeyError`` when neither ``reference_answer`` nor ``answer`` + is present; the row is still emitted with an empty reference.""" + pipeline = _make_pipeline() + + df = pipeline._prepare_dataframe([{"query": "what is foo?"}]) + + assert len(df) == 1 + assert df.iloc[0]["query"] == "what is foo?" + assert df.iloc[0]["reference_answer"] == "" + pipeline.retriever.retrieve.assert_called_once_with("what is foo?", 3) + + def test_empty_reference_answer_falls_through_to_answer(self) -> None: + """``reference_answer == ""`` is falsy in Python, so the fallback + kicks in. 
``answer`` wins when present.""" + pipeline = _make_pipeline() + + df = pipeline._prepare_dataframe([{"query": "q", "reference_answer": "", "answer": "alt"}]) + + assert df.iloc[0]["reference_answer"] == "alt" + + def test_empty_reference_answer_without_answer_key_does_not_raise(self) -> None: + """This is the subtle second crash path the reviewer identified: + empty reference_answer + no answer key used to raise ``KeyError`` + because ``or`` treats "" as falsy and fell through to + ``pair["answer"]``. It must now yield an empty string instead.""" + pipeline = _make_pipeline() + + df = pipeline._prepare_dataframe([{"query": "q", "reference_answer": ""}]) + + assert df.iloc[0]["reference_answer"] == "" + + def test_reference_answer_present_is_preserved(self) -> None: + """Happy path regression guard: an explicit reference wins.""" + pipeline = _make_pipeline() + + df = pipeline._prepare_dataframe([{"query": "q", "reference_answer": "truth", "answer": "other"}]) + + assert df.iloc[0]["reference_answer"] == "truth" + + def test_only_answer_key_is_used_as_legacy_fallback(self) -> None: + """Backward-compat path documented in ``evaluate.__doc__``: a legacy + ground-truth CSV only has ``answer``, not ``reference_answer``.""" + pipeline = _make_pipeline() + + df = pipeline._prepare_dataframe([{"query": "q", "answer": "legacy"}]) + + assert df.iloc[0]["reference_answer"] == "legacy" + + +class TestRetrievalExceptionPathMatchesHappyPath: + """Structural invariant: the exception-branch fallback at line 294 and + the happy-branch fallback at line 268 must produce the same reference + string for the same input. This test locks them together so a future + refactor cannot reintroduce the divergence Greptile flagged.""" + + @pytest.mark.parametrize( + "pair,expected_reference", + [ + ({"query": "q"}, ""), + ({"query": "q", "reference_answer": ""}, ""), + ({"query": "q", "reference_answer": "r"}, "r"), + ({"query": "q", "answer": "a"}, "a"), + ({"query": "q", "reference_answer": "", "answer": "a"}, "a"), + ], + ) + def test_both_branches_agree(self, pair: dict, expected_reference: str) -> None: + happy_pipeline = _make_pipeline() + happy_df = happy_pipeline._prepare_dataframe([pair]) + assert happy_df.iloc[0]["reference_answer"] == expected_reference + + boom_pipeline = _make_pipeline(retrieve_side_effect=RuntimeError("lancedb down")) + boom_df = boom_pipeline._prepare_dataframe([pair]) + assert boom_df.iloc[0]["reference_answer"] == expected_reference + assert boom_df.iloc[0]["context"] == [] diff --git a/nemo_retriever/tests/test_file_retriever.py b/nemo_retriever/tests/test_evaluation_retrievers.py similarity index 100% rename from nemo_retriever/tests/test_file_retriever.py rename to nemo_retriever/tests/test_evaluation_retrievers.py From 117defc4117180309d19c82158a5c20d1183097c Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:34:59 +0000 Subject: [PATCH 09/10] regenerate uv lock --- nemo_retriever/uv.lock | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nemo_retriever/uv.lock b/nemo_retriever/uv.lock index 7fcb234e5..75617f25a 100644 --- a/nemo_retriever/uv.lock +++ b/nemo_retriever/uv.lock @@ -2792,6 +2792,7 @@ all = [ { name = "duckdb-engine" }, { name = "easydict" }, { name = "einops" }, + { name = "litellm" }, { name = "nemotron-graphic-elements-v1" }, { name = "nemotron-ocr", marker = "sys_platform == 'linux'" }, { name = "nemotron-page-elements-v3" }, @@ -2883,7 +2884,7 @@ requires-dist = [ { name = 
"langchain-nvidia-ai-endpoints", specifier = ">=0.3.0" }, { name = "litellm", marker = "extra == 'llm'", specifier = ">=1.40.0" }, { name = "markitdown" }, - { name = "nemo-retriever", extras = ["benchmarks", "local", "multimedia", "stores"], marker = "extra == 'all'" }, + { name = "nemo-retriever", extras = ["benchmarks", "llm", "local", "multimedia", "stores"], marker = "extra == 'all'" }, { name = "nemotron-graphic-elements-v1", marker = "extra == 'local'", specifier = ">=0.dev0", index = "https://test.pypi.org/simple/" }, { name = "nemotron-ocr", marker = "sys_platform == 'linux' and extra == 'local'", specifier = ">=0.dev0", index = "https://test.pypi.org/simple/" }, { name = "nemotron-page-elements-v3", marker = "extra == 'local'", specifier = ">=0.dev0", index = "https://test.pypi.org/simple/" }, From 321e8855834d8b95ef67ebcc86b6932b20359045 Mon Sep 17 00:00:00 2001 From: Kyle Zheng <126034466+KyleZheng1284@users.noreply.github.com> Date: Wed, 22 Apr 2026 17:52:31 +0000 Subject: [PATCH 10/10] change judge shape to support retry param --- .../src/nemo_retriever/evaluation/config.py | 2 + .../src/nemo_retriever/evaluation/judging.py | 3 + .../src/nemo_retriever/retriever.py | 1 + .../tests/test_evaluation_config.py | 89 ++++++++++++++++++- nemo_retriever/tests/test_llm_params.py | 58 ++++++++++++ 5 files changed, 151 insertions(+), 2 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/config.py b/nemo_retriever/src/nemo_retriever/evaluation/config.py index dbd3fd7cc..f279ffd5f 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/config.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/config.py @@ -321,6 +321,7 @@ def build_eval_chain( api_base=judge_cfg.get("api_base"), api_key=judge_cfg.get("api_key"), extra_params=judge_cfg.get("extra_params"), + num_retries=judge_cfg.get("num_retries", 3), timeout=judge_cfg.get("timeout", default_timeout), max_workers=execution.get("max_workers", 8), ) @@ -391,6 +392,7 @@ def build_eval_pipeline(config: dict) -> "QAEvalPipeline": api_base=judge_cfg.get("api_base"), api_key=judge_cfg.get("api_key"), extra_params=judge_cfg.get("extra_params"), + num_retries=judge_cfg.get("num_retries", 3), timeout=judge_cfg.get("timeout", default_timeout), ) diff --git a/nemo_retriever/src/nemo_retriever/evaluation/judging.py b/nemo_retriever/src/nemo_retriever/evaluation/judging.py index d7ef29efa..f1d82129f 100644 --- a/nemo_retriever/src/nemo_retriever/evaluation/judging.py +++ b/nemo_retriever/src/nemo_retriever/evaluation/judging.py @@ -34,6 +34,7 @@ def __init__( api_base: Optional[str] = None, api_key: Optional[str] = None, extra_params: Optional[dict[str, Any]] = None, + num_retries: int = 3, timeout: float = 120.0, max_workers: int = 8, ) -> None: @@ -42,6 +43,7 @@ def __init__( api_base=api_base, api_key=api_key, extra_params=extra_params, + num_retries=num_retries, timeout=timeout, max_workers=max_workers, ) @@ -50,6 +52,7 @@ def __init__( api_base=api_base, api_key=api_key, extra_params=extra_params, + num_retries=num_retries, timeout=timeout, ) self._max_workers = max_workers diff --git a/nemo_retriever/src/nemo_retriever/retriever.py b/nemo_retriever/src/nemo_retriever/retriever.py index 68b58175c..d2c67bf52 100644 --- a/nemo_retriever/src/nemo_retriever/retriever.py +++ b/nemo_retriever/src/nemo_retriever/retriever.py @@ -805,6 +805,7 @@ def judge( api_base=transport.api_base, api_key=transport.api_key, extra_params=dict(transport.extra_params) if transport.extra_params else None, + num_retries=transport.num_retries, 
timeout=transport.timeout, ) else: diff --git a/nemo_retriever/tests/test_evaluation_config.py b/nemo_retriever/tests/test_evaluation_config.py index 1fb25fce6..3572aca81 100644 --- a/nemo_retriever/tests/test_evaluation_config.py +++ b/nemo_retriever/tests/test_evaluation_config.py @@ -6,14 +6,18 @@ Focus: the fail-fast contract in :func:`_normalize_config` that guards ``build_eval_chain`` / ``build_eval_pipeline`` from silently collapsing -heterogeneous-judge configs to a single judge. +heterogeneous-judge configs to a single judge, plus the ``num_retries`` +plumbing contract from the judge config block down to the constructed +operator / ``LLMJudge``. """ from __future__ import annotations +from unittest.mock import MagicMock, patch + import pytest -from nemo_retriever.evaluation.config import _normalize_config +from nemo_retriever.evaluation.config import _normalize_config, build_eval_chain, build_eval_pipeline def _make_multi_judge_config() -> dict: @@ -107,3 +111,84 @@ def test_normalize_config_legacy_schema_passes() -> None: eval_judges = {e["judge"] for e in normalized["evaluations"]} assert eval_judges == {"judge-x"} + + +def _make_minimal_legacy_config_with_judge_retries(num_retries: int) -> dict: + """Minimal legacy-schema config carrying an explicit judge.num_retries.""" + return { + "generators": [ + {"name": "gen-a", "model": "provider/gen-a", "api_key": "k"}, + ], + "judge": { + "name": "judge-x", + "model": "provider/judge-x", + "api_key": "k", + "num_retries": num_retries, + }, + "retrieval": {"type": "file", "file_path": "dummy.json"}, + "dataset": {"source": "dummy.csv"}, + "execution": {"top_k": 5, "max_workers": 2}, + } + + +def test_build_eval_chain_forwards_judge_num_retries() -> None: + """``judge.num_retries`` from YAML must reach ``JudgingOperator``. + + Before this fix the ``JudgingOperator`` constructor did not accept + ``num_retries`` at all, so any value a user put in the judge block was + silently dropped and the operator always ran with the default ``3``. + """ + config = _make_minimal_legacy_config_with_judge_retries(num_retries=9) + + with ( + patch("nemo_retriever.evaluation.retrieval_loader.RetrievalLoaderOperator"), + patch("nemo_retriever.evaluation.generation.QAGenerationOperator"), + patch("nemo_retriever.evaluation.scoring_operator.ScoringOperator"), + patch("nemo_retriever.evaluation.judging.JudgingOperator") as mock_judge_op, + ): + mock_judge_op.return_value = MagicMock() + + build_eval_chain(config) + + mock_judge_op.assert_called_once() + assert mock_judge_op.call_args.kwargs["num_retries"] == 9 + + +def test_build_eval_chain_defaults_judge_num_retries_when_absent() -> None: + """When ``judge.num_retries`` is omitted, the default ``3`` must be passed.""" + config = _make_minimal_legacy_config_with_judge_retries(num_retries=3) + config["judge"].pop("num_retries") + + with ( + patch("nemo_retriever.evaluation.retrieval_loader.RetrievalLoaderOperator"), + patch("nemo_retriever.evaluation.generation.QAGenerationOperator"), + patch("nemo_retriever.evaluation.scoring_operator.ScoringOperator"), + patch("nemo_retriever.evaluation.judging.JudgingOperator") as mock_judge_op, + ): + mock_judge_op.return_value = MagicMock() + + build_eval_chain(config) + + assert mock_judge_op.call_args.kwargs["num_retries"] == 3 + + +def test_build_eval_pipeline_forwards_judge_num_retries() -> None: + """``judge.num_retries`` from YAML must reach ``LLMJudge.from_kwargs``. 
+ + This is the sibling path to :func:`build_eval_chain` -- same bug + shape, different construction surface. + """ + config = _make_minimal_legacy_config_with_judge_retries(num_retries=11) + + with ( + patch("nemo_retriever.evaluation.retrievers.FileRetriever"), + patch("nemo_retriever.llm.clients.LiteLLMClient"), + patch("nemo_retriever.evaluation.orchestrator.QAEvalPipeline"), + patch("nemo_retriever.llm.clients.LLMJudge.from_kwargs") as mock_from_kwargs, + ): + mock_from_kwargs.return_value = MagicMock() + + build_eval_pipeline(config) + + mock_from_kwargs.assert_called_once() + assert mock_from_kwargs.call_args.kwargs["num_retries"] == 11 diff --git a/nemo_retriever/tests/test_llm_params.py b/nemo_retriever/tests/test_llm_params.py index 8cf65b955..9453ec50d 100644 --- a/nemo_retriever/tests/test_llm_params.py +++ b/nemo_retriever/tests/test_llm_params.py @@ -317,6 +317,64 @@ def test_judging_operator_constructs_cleanly(self): assert op._judge.model == "nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1" assert op._judge._client.sampling.temperature == 0.0 + def test_judging_operator_plumbs_num_retries_to_inner_judge(self): + """JudgingOperator(num_retries=...) must flow down to the LLMJudge it + instantiates internally. + + Before this fix, ``JudgingOperator.__init__`` had no ``num_retries`` + parameter, so the pre-built ``LLMJudge.transport.num_retries`` set by + a pipeline caller was silently dropped at the operator boundary and + the operator always ran with ``LLMJudge``'s default (3).""" + from nemo_retriever.evaluation.judging import JudgingOperator + + op = JudgingOperator( + model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1", + num_retries=7, + ) + assert op._judge._client.transport.num_retries == 7 + + def test_pipeline_builder_judge_forwards_transport_num_retries(self): + """RetrieverPipelineBuilder.judge(judge) unpacks transport.* onto the + operator. num_retries must be in that unpack, symmetric with the + identical .generate() branch at retriever.py:762.""" + from unittest.mock import MagicMock + + from nemo_retriever.evaluation.judging import JudgingOperator + from nemo_retriever.llm.clients import LLMJudge + from nemo_retriever.retriever import RetrieverPipelineBuilder + + retriever = MagicMock() + retriever.top_k = 5 + builder = RetrieverPipelineBuilder(retriever, top_k=5) + + judge = LLMJudge.from_kwargs( + model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1", + num_retries=7, + ) + builder.judge(judge) + + judging_ops = [s for s in builder._steps if isinstance(s, JudgingOperator)] + assert len(judging_ops) == 1 + assert judging_ops[0]._judge._client.transport.num_retries == 7 + + def test_pipeline_builder_judge_defaults_num_retries_when_flat_kwargs(self): + """The flat ``model=...`` branch of .judge() must still default + num_retries to 3, preserving the current default behaviour.""" + from unittest.mock import MagicMock + + from nemo_retriever.evaluation.judging import JudgingOperator + from nemo_retriever.retriever import RetrieverPipelineBuilder + + retriever = MagicMock() + retriever.top_k = 5 + builder = RetrieverPipelineBuilder(retriever, top_k=5) + + builder.judge(model="nvidia_nim/mistralai/mixtral-8x22b-instruct-v0.1") + + judging_ops = [s for s in builder._steps if isinstance(s, JudgingOperator)] + assert len(judging_ops) == 1 + assert judging_ops[0]._judge._client.transport.num_retries == 3 + class TestApiKeyRedaction: """Guard the repr/str of every transport params object against key leakage.