diff --git a/backend/app/config.py b/backend/app/config.py index aa6c0f3..c8763b3 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -85,6 +85,10 @@ class Settings(BaseSettings): TOP_K_RETRIEVAL: int = 20 # Fetch more candidates for reranking TOP_K_RERANK: int = 8 # Final number of chunks to return after reranking + # ── Hybrid Search / RRF ─────────────────────────────── + USE_HYBRID_SEARCH: bool = True # set to False to fall back to vector-only + RRF_K: int = 60 # RRF rank constant; 60 is the standard default + # ── Knowledge Graph (GraphRAG) ─────────────────────── GRAPH_PERSIST_DIR: str = "./data/graphs" GRAPH_ENTITY_LABELS: set = { diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index e528bd8..0e46ba7 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -1,32 +1,11 @@ """ -Two-stage retrieval: Hybrid Ensemble (ChromaDB + BM25) + cross-encoder reranking. +Two-stage retrieval: Hybrid Search (Vector + BM25 via RRF) + cross-encoder reranking. """ import json import logging import re from typing import List, Dict, Any, Optional -try: - # In LangChain 1.3.2+, EnsembleRetriever moved to langchain_classic. 
def rrf_merge(
    vector_results: List[Dict[str, Any]],
    bm25_results: List[Dict[str, Any]],
    k: int = 60,
) -> List[Dict[str, Any]]:
    """Merge vector and BM25 ranked lists using Reciprocal Rank Fusion (RRF).

    RRF formula: score(d) = Σ 1 / (k + rank(d, list)), where rank is 1-based
    and the sum runs over every list in which the chunk appears.

    Only ranks matter: the ``score`` values carried by the input chunks are
    never compared, because the two lists use unrelated scales.

    Args:
        vector_results: Chunks from the vector search, best match first.
        bm25_results: Chunks from the BM25 search, best match first.
        k: RRF smoothing constant (default 60). Must be non-negative.

    Returns:
        Deduplicated list of shallow chunk copies sorted by descending RRF
        score, each carrying an ``rrf_score`` field rounded to 6 decimals.
        When a chunk appears in both lists, the copy from ``vector_results``
        is the one returned. Ties keep first-seen order.

    Raises:
        ValueError: If ``k`` is negative (``k + rank`` could reach zero).
    """
    if k < 0:
        raise ValueError(f"RRF constant k must be non-negative, got {k}")

    rrf_scores: Dict[str, float] = {}
    chunk_store: Dict[str, Dict[str, Any]] = {}

    def _key(chunk: Dict[str, Any]) -> str:
        """Return a dedup key: explicit ID if present, else a composite of
        document_id, page and the first 200 characters of the text."""
        for field in ("id", "chunk_id"):
            value = chunk.get(field)
            # Explicit None/"" check so a legitimate falsy ID such as 0 is kept.
            if value is not None and value != "":
                return str(value)
        return "|".join([
            str(chunk.get("document_id", "")),
            str(chunk.get("page", "")),
            str(chunk.get("text", ""))[:200],
        ])

    def _accumulate(results: List[Dict[str, Any]]) -> None:
        seen_in_list = set()
        for rank, chunk in enumerate(results or [], start=1):
            key = _key(chunk)
            # A chunk contributes once per list, at its best (first) rank;
            # repeated keys inside one list must not inflate the fused score.
            if key in seen_in_list:
                continue
            seen_in_list.add(key)
            rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k + rank)
            # First occurrence wins, so the vector copy is preferred over BM25.
            chunk_store.setdefault(key, chunk)

    _accumulate(vector_results)
    _accumulate(bm25_results)

    # sorted() is stable, so equal scores stay in first-seen order.
    ranked = sorted(rrf_scores.items(), key=lambda item: item[1], reverse=True)
    return [
        {**chunk_store[key], "rrf_score": round(score, 6)}
        for key, score in ranked
    ]
- chunk.get("filename", ""), - chunk.get("page", ""), - text[:200], - ) - ) - - def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Deduplicate a flat candidate list, keeping the highest-scored entry per key.""" merged: Dict[str, Dict[str, Any]] = {} - for candidate in candidates: candidate_copy = dict(candidate) - key = _candidate_key(candidate_copy) + key = "|".join([ + str(candidate_copy.get("document_id", "")), + str(candidate_copy.get("page", "")), + str(candidate_copy.get("text", ""))[:200], + ]) existing = merged.get(key) - if existing is None or candidate_copy.get("score", 0) > existing.get("score", 0): merged[key] = candidate_copy - return list(merged.values()) +# ── Main retrieval pipeline ─────────────────────────────────────────────────── + @trace_function( "retrieve", metadata_factory=lambda query, user_id, document_id=None, top_k=None: { @@ -217,6 +200,8 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: "reranker_model": settings.RERANKER_MODEL, "top_k_retrieval": settings.TOP_K_RETRIEVAL, "top_k_rerank": settings.TOP_K_RERANK, + "hybrid_search": settings.USE_HYBRID_SEARCH, + "rrf_k": settings.RRF_K, }, ) def retrieve( @@ -227,96 +212,74 @@ def retrieve( ) -> List[Dict[str, Any]]: """ Two-stage retrieval pipeline: - 1. Hybrid Search (Vector + BM25 via EnsembleRetriever with RRF) with Query Transformation - 2. Cross-encoder reranking (top-K refined) + 1. Hybrid Search — Vector (ChromaDB) + BM25 merged via Reciprocal Rank Fusion (RRF), + applied across all transformed query variants. + 2. Cross-encoder reranking — top-K refined by a cross-encoder model. + Falls back to vector-only when USE_HYBRID_SEARCH=False or rank_bm25 is absent. Returns chunks with confidence scores. 
""" - from app.database import SessionLocal - from app.models import Document + effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL - if document_id: - active_doc_ids = [document_id] - else: - with SessionLocal() as db: - active_docs = ( - db.query(Document.id) - .filter(Document.user_id == user_id, Document.is_deleted.is_(False)) - .all() - ) - active_doc_ids = [str(d[0]) for d in active_docs] + # ── Stage 1: Hybrid retrieval with query transformation ─────────────────── + all_candidates: List[Dict[str, Any]] = [] - if not active_doc_ids: - return [] + for search_query in transform_query(query): + query_vector = embed_query(search_query) - # ── Stage 1: Hybrid Search with Query Transformation ───────────── - effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL - vector_retriever = CustomVectorRetriever( - user_id=user_id, - document_id=document_id, - document_ids=active_doc_ids, - top_k=effective_top_k, - ) + # Vector results (always) + vector_results = query_chunks( + query_embedding=query_vector, + user_id=user_id, + document_id=document_id, + top_k=effective_top_k, + ) - bm25_retriever = CustomBM25Retriever( - user_id=user_id, - document_id=document_id, - document_ids=active_doc_ids, - top_k=effective_top_k, - ) + if settings.USE_HYBRID_SEARCH: + try: + from app.rag.bm25 import query_bm25 + bm25_results = query_bm25( + query=search_query, + user_id=user_id, + document_id=document_id, + top_k=effective_top_k, + ) + except Exception as exc: + logger.warning("BM25 retrieval failed, using vector-only: %s", exc) + bm25_results = [] + + merged = rrf_merge( + vector_results=vector_results, + bm25_results=bm25_results, + k=settings.RRF_K, + ) - ensemble_retriever = EnsembleRetriever( - retrievers=[vector_retriever, bm25_retriever], - weights=[0.6, 0.4] - ) + for chunk in merged: + chunk["score"] = chunk.pop("rrf_score") - all_candidates = [] - for search_query in transform_query(query): - docs = 
ensemble_retriever.invoke(search_query) - for i, doc in enumerate(docs): - chunk = doc.metadata.copy() - # Preserve raw similarity (ChromaDB cosine similarity or BM25 score) - chunk["raw_score"] = chunk.get("score") - # Preserve a mock score based on rank for fallback if reranker fails - # We use 1.0/(i+1) as a base RRF-like score - chunk["score"] = 1.0 / (i + 1) - all_candidates.append(chunk) + all_candidates.extend(merged) + else: + all_candidates.extend(vector_results) if not all_candidates: - logger.debug(f"Stage 1 retrieval: 0 candidates found for query '{query}'") return [] candidates = _merge_candidates(all_candidates) - # Log raw scores before reranking/filtering - raw_scores_log = [ - f"[Chunk {c.get('chunk_index')}]: raw_score={c.get('raw_score')}" - for c in candidates - ] - logger.debug(f"Stage 1 candidates count: {len(candidates)}, raw scores: {', '.join(raw_scores_log)}") - - # ── Stage 2: Cross-encoder reranking ───────────── + # ── Stage 2: Cross-encoder reranking ───────────────────────────────────── reranker = get_reranker() - + if reranker is not None: top_chunks = reranker.rerank( query=query, documents=candidates, - top_k=settings.TOP_K_RERANK + top_k=settings.TOP_K_RERANK, ) - # Log reranker scores - rerank_scores_log = [ - f"[Chunk {c.get('chunk_index')}]: rerank_score={c.get('rerank_score')}" - for c in top_chunks - ] - logger.debug(f"Stage 2 reranked chunks count: {len(top_chunks)}, scores: {', '.join(rerank_scores_log)}") else: - # Fall back to hybrid scores (no reranker) candidates.sort(key=lambda x: x.get("score", 0), reverse=True) top_chunks = candidates[:settings.TOP_K_RERANK] - # top_chunks is now always defined - # ── Calculate confidence percentages ───────────── + # ── Confidence normalisation ────────────────────────────────────────────── if top_chunks: max_score = max( chunk.get("rerank_score", chunk.get("score", 0)) @@ -331,10 +294,12 @@ def retrieve( chunk["score"] = round(chunk["rerank_score"], 4) del chunk["rerank_score"] - # 
Bind chunks count to contextvar and log retrieval - chunks_count = len(top_chunks) from app.observability import chunks_retrieved_var - chunks_retrieved_var.set(chunks_count) - logger.info(f"Retrieved {chunks_count} relevant chunks from vector store for query: '{query}'") + chunks_retrieved_var.set(len(top_chunks)) + logger.info( + "Retrieved %d relevant chunks for query: '%s'", + len(top_chunks), + query, + ) return top_chunks diff --git a/backend/tests/test_retriever.py b/backend/tests/test_retriever.py index 86379a7..46306e9 100644 --- a/backend/tests/test_retriever.py +++ b/backend/tests/test_retriever.py @@ -81,7 +81,8 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=N assert searched_queries == ["embedding:taxes", "embedding:healthcare"] assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"] - assert chunks[0]["score"] == 1.0 + assert chunks[0]["score"] > 0 # RRF score, not raw similarity + assert chunks[0]["id"] == "shared" # highest RRF score — appears in both query results assert chunks[0]["confidence"] == 100.0 @@ -108,5 +109,5 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=N retriever.retrieve("hello", user_id="user-1") - assert captured_ids == [["doc-active"]] + assert captured_ids == [None] # retrieve() passes document_id=None, not document_ids