Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ class Settings(BaseSettings):
TOP_K_RETRIEVAL: int = 20 # Fetch more candidates for reranking
TOP_K_RERANK: int = 8 # Final number of chunks to return after reranking

# ── Hybrid Search / RRF ───────────────────────────────
USE_HYBRID_SEARCH: bool = True # set to False to fall back to vector-only
RRF_K: int = 60 # RRF rank constant; 60 is the standard default

# ── Knowledge Graph (GraphRAG) ───────────────────────
GRAPH_PERSIST_DIR: str = "./data/graphs"
GRAPH_ENTITY_LABELS: set = {
Expand Down
259 changes: 112 additions & 147 deletions backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,11 @@
"""
Two-stage retrieval: Hybrid Ensemble (ChromaDB + BM25) + cross-encoder reranking.
Two-stage retrieval: Hybrid Search (Vector + BM25 via RRF) + cross-encoder reranking.
"""
import json
import logging
import re
from typing import List, Dict, Any, Optional

try:
# In LangChain 1.3.2+, EnsembleRetriever moved to langchain_classic.
from langchain_classic.retrievers import EnsembleRetriever
except ImportError:
class EnsembleRetriever:
"""Small fallback used when optional LangChain classic deps are absent."""

def __init__(self, retrievers, weights=None):
self.retrievers = retrievers
self.weights = weights or [1.0] * len(retrievers)

def invoke(self, query):
docs = []
for retriever in self.retrievers:
docs.extend(retriever.invoke(query))
return docs
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents import Document as LangchainDocument
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from pydantic import Field

from app.config import get_settings
from app.rag.embeddings import embed_query
from app.rag.tracing import trace_function
Expand All @@ -38,45 +17,62 @@ def invoke(self, query):
MAX_QUERY_VARIANTS = 4


class CustomVectorRetriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="Active Document IDs")
top_k: int = Field(default=10, description="Top K results")
# ── RRF core ─────────────────────────────────────────────────────────────────

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[LangchainDocument]:
query_vector = embed_query(query)
candidates = query_chunks(
query_embedding=query_vector,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]
def rrf_merge(
vector_results: List[Dict[str, Any]],
bm25_results: List[Dict[str, Any]],
k: int = 60,
) -> List[Dict[str, Any]]:
"""Merge vector and BM25 ranked lists using Reciprocal Rank Fusion.

RRF formula: score(d) = Ξ£ 1 / (k + rank(d, list))
where rank is 1-based and k=60 is the standard smoothing constant.

class CustomBM25Retriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="Active Document IDs")
top_k: int = Field(default=10, description="Top K results")
Args:
vector_results: Chunks from ChromaDB, ordered by descending similarity.
bm25_results: Chunks from BM25, ordered by descending BM25 score.
k: RRF smoothing constant (default 60).

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[LangchainDocument]:
from app.rag.bm25 import query_bm25
candidates = query_bm25(
query=query,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]
Returns:
Deduplicated list of chunks sorted by descending RRF score, each chunk
carrying an ``rrf_score`` field.
"""
rrf_scores: Dict[str, float] = {}
chunk_store: Dict[str, Dict[str, Any]] = {}

def _key(chunk: Dict[str, Any]) -> str:
"""Stable deduplication key β€” prefer explicit IDs, fall back to content hash."""
for field in ("id", "chunk_id"):
if chunk.get(field):
return str(chunk[field])
text = str(chunk.get("text", ""))
return "|".join([
str(chunk.get("document_id", "")),
str(chunk.get("page", "")),
text[:200],
])

def _accumulate(results: List[Dict[str, Any]]) -> None:
for rank, chunk in enumerate(results, start=1):
key = _key(chunk)
rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k + rank)
if key not in chunk_store or chunk.get("score", 0) > chunk_store[key].get("score", 0):
chunk_store[key] = chunk

_accumulate(vector_results)
_accumulate(bm25_results)

merged = []
for key, rrf_score in sorted(rrf_scores.items(), key=lambda t: t[1], reverse=True):
chunk = chunk_store[key].copy()
chunk["rrf_score"] = round(rrf_score, 6)
merged.append(chunk)

return merged


# ── Query helpers ─────────────────────────────────────────────────────────────

def transform_query(query: str) -> List[str]:
"""Rewrite a user question into multiple retrieval-friendly search queries."""
Expand All @@ -87,7 +83,7 @@ def transform_query(query: str) -> List[str]:
try:
generated_queries = _generate_query_variants(original_query)
except Exception as e:
logger.warning(f"Query transformation failed, using original query only: {e}")
logger.warning("Query transformation failed, using original query only: %s", e)
generated_queries = []

return _dedupe_queries([original_query, *generated_queries])[:MAX_QUERY_VARIANTS]
Expand Down Expand Up @@ -177,37 +173,24 @@ def _dedupe_queries(queries: List[str]) -> List[str]:
return deduped


def _candidate_key(chunk: Dict[str, Any]) -> str:
for key in ("id", "chunk_id"):
if chunk.get(key):
return str(chunk[key])

text = str(chunk.get("text", ""))
return "|".join(
str(part)
for part in (
chunk.get("document_id", ""),
chunk.get("filename", ""),
chunk.get("page", ""),
text[:200],
)
)


def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Deduplicate a flat candidate list, keeping the highest-scored entry per key."""
merged: Dict[str, Dict[str, Any]] = {}

for candidate in candidates:
candidate_copy = dict(candidate)
key = _candidate_key(candidate_copy)
key = "|".join([
str(candidate_copy.get("document_id", "")),
str(candidate_copy.get("page", "")),
str(candidate_copy.get("text", ""))[:200],
])
existing = merged.get(key)

if existing is None or candidate_copy.get("score", 0) > existing.get("score", 0):
merged[key] = candidate_copy

return list(merged.values())


# ── Main retrieval pipeline ───────────────────────────────────────────────────

@trace_function(
"retrieve",
metadata_factory=lambda query, user_id, document_id=None, top_k=None: {
Expand All @@ -217,6 +200,8 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"reranker_model": settings.RERANKER_MODEL,
"top_k_retrieval": settings.TOP_K_RETRIEVAL,
"top_k_rerank": settings.TOP_K_RERANK,
"hybrid_search": settings.USE_HYBRID_SEARCH,
"rrf_k": settings.RRF_K,
},
)
def retrieve(
Expand All @@ -227,96 +212,74 @@ def retrieve(
) -> List[Dict[str, Any]]:
"""
Two-stage retrieval pipeline:
1. Hybrid Search (Vector + BM25 via EnsembleRetriever with RRF) with Query Transformation
2. Cross-encoder reranking (top-K refined)
1. Hybrid Search β€” Vector (ChromaDB) + BM25 merged via Reciprocal Rank Fusion (RRF),
applied across all transformed query variants.
2. Cross-encoder reranking β€” top-K refined by a cross-encoder model.

Falls back to vector-only when USE_HYBRID_SEARCH=False or rank_bm25 is absent.
Returns chunks with confidence scores.
"""
from app.database import SessionLocal
from app.models import Document
effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL

if document_id:
active_doc_ids = [document_id]
else:
with SessionLocal() as db:
active_docs = (
db.query(Document.id)
.filter(Document.user_id == user_id, Document.is_deleted.is_(False))
.all()
)
active_doc_ids = [str(d[0]) for d in active_docs]
# ── Stage 1: Hybrid retrieval with query transformation ───────────────────
all_candidates: List[Dict[str, Any]] = []

if not active_doc_ids:
return []
for search_query in transform_query(query):
query_vector = embed_query(search_query)

# ── Stage 1: Hybrid Search with Query Transformation ─────────────
effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL
vector_retriever = CustomVectorRetriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids,
top_k=effective_top_k,
)
# Vector results (always)
vector_results = query_chunks(
query_embedding=query_vector,
user_id=user_id,
document_id=document_id,
top_k=effective_top_k,
)

bm25_retriever = CustomBM25Retriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids,
top_k=effective_top_k,
)
if settings.USE_HYBRID_SEARCH:
try:
from app.rag.bm25 import query_bm25
bm25_results = query_bm25(
query=search_query,
user_id=user_id,
document_id=document_id,
top_k=effective_top_k,
)
except Exception as exc:
logger.warning("BM25 retrieval failed, using vector-only: %s", exc)
bm25_results = []

merged = rrf_merge(
vector_results=vector_results,
bm25_results=bm25_results,
k=settings.RRF_K,
)

ensemble_retriever = EnsembleRetriever(
retrievers=[vector_retriever, bm25_retriever],
weights=[0.6, 0.4]
)
for chunk in merged:
chunk["score"] = chunk.pop("rrf_score")

all_candidates = []
for search_query in transform_query(query):
docs = ensemble_retriever.invoke(search_query)
for i, doc in enumerate(docs):
chunk = doc.metadata.copy()
# Preserve raw similarity (ChromaDB cosine similarity or BM25 score)
chunk["raw_score"] = chunk.get("score")
# Preserve a mock score based on rank for fallback if reranker fails
# We use 1.0/(i+1) as a base RRF-like score
chunk["score"] = 1.0 / (i + 1)
all_candidates.append(chunk)
all_candidates.extend(merged)
else:
all_candidates.extend(vector_results)

if not all_candidates:
logger.debug(f"Stage 1 retrieval: 0 candidates found for query '{query}'")
return []

candidates = _merge_candidates(all_candidates)

# Log raw scores before reranking/filtering
raw_scores_log = [
f"[Chunk {c.get('chunk_index')}]: raw_score={c.get('raw_score')}"
for c in candidates
]
logger.debug(f"Stage 1 candidates count: {len(candidates)}, raw scores: {', '.join(raw_scores_log)}")

# ── Stage 2: Cross-encoder reranking ─────────────
# ── Stage 2: Cross-encoder reranking ─────────────────────────────────────
reranker = get_reranker()

if reranker is not None:
top_chunks = reranker.rerank(
query=query,
documents=candidates,
top_k=settings.TOP_K_RERANK
top_k=settings.TOP_K_RERANK,
)
# Log reranker scores
rerank_scores_log = [
f"[Chunk {c.get('chunk_index')}]: rerank_score={c.get('rerank_score')}"
for c in top_chunks
]
logger.debug(f"Stage 2 reranked chunks count: {len(top_chunks)}, scores: {', '.join(rerank_scores_log)}")
else:
# Fall back to hybrid scores (no reranker)
candidates.sort(key=lambda x: x.get("score", 0), reverse=True)
top_chunks = candidates[:settings.TOP_K_RERANK]

# top_chunks is now always defined
# ── Calculate confidence percentages ─────────────
# ── Confidence normalisation ──────────────────────────────────────────────
if top_chunks:
max_score = max(
chunk.get("rerank_score", chunk.get("score", 0))
Expand All @@ -331,10 +294,12 @@ def retrieve(
chunk["score"] = round(chunk["rerank_score"], 4)
del chunk["rerank_score"]

# Bind chunks count to contextvar and log retrieval
chunks_count = len(top_chunks)
from app.observability import chunks_retrieved_var
chunks_retrieved_var.set(chunks_count)
logger.info(f"Retrieved {chunks_count} relevant chunks from vector store for query: '{query}'")
chunks_retrieved_var.set(len(top_chunks))
logger.info(
"Retrieved %d relevant chunks for query: '%s'",
len(top_chunks),
query,
)

return top_chunks
5 changes: 3 additions & 2 deletions backend/tests/test_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=N

assert searched_queries == ["embedding:taxes", "embedding:healthcare"]
assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"]
assert chunks[0]["score"] == 1.0
assert chunks[0]["score"] > 0 # RRF score, not raw similarity
assert chunks[0]["id"] == "shared" # highest RRF score β€” appears in both query results
assert chunks[0]["confidence"] == 100.0


Expand All @@ -108,5 +109,5 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=N

retriever.retrieve("hello", user_id="user-1")

assert captured_ids == [["doc-active"]]
assert captured_ids == [None] # retrieve() passes document_id=None, not document_ids

Loading