Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 63 additions & 10 deletions backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,41 @@ def _accumulate(results: List[Dict[str, Any]]) -> None:
return merged


def rrf_merge_multiple(
ranked_lists: List[List[Dict[str, Any]]],
k: int = 60,
) -> List[Dict[str, Any]]:
"""Merge multiple ranked lists using Reciprocal Rank Fusion."""
rrf_scores: Dict[str, float] = {}
chunk_store: Dict[str, Dict[str, Any]] = {}

def _key(chunk: Dict[str, Any]) -> str:
for field in ("id", "chunk_id"):
if chunk.get(field):
return str(chunk[field])
text = str(chunk.get("text", ""))
return "|".join([
str(chunk.get("document_id", "")),
str(chunk.get("page", "")),
text[:200],
])

for lst in ranked_lists:
for rank, chunk in enumerate(lst, start=1):
key = _key(chunk)
rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k + rank)
if key not in chunk_store or chunk.get("score", 0) > chunk_store[key].get("score", 0):
chunk_store[key] = chunk

merged = []
for key, rrf_score in sorted(rrf_scores.items(), key=lambda t: t[1], reverse=True):
chunk = chunk_store[key].copy()
chunk["rrf_score"] = round(rrf_score, 6)
merged.append(chunk)

return merged


# ── Query helpers ─────────────────────────────────────────────────────────────

def transform_query(query: str) -> List[str]:
Expand All @@ -98,15 +133,18 @@ def _generate_query_variants(query: str) -> List[str]:

client = InferenceClient(token=settings.HF_TOKEN)
prompt = (
"Rewrite the user question into concise semantic search queries for document retrieval. "
"Split independent topics into separate queries. Return a JSON array of strings only. "
"Decompose the user's complex multi-part question into simple, distinct semantic sub-queries. "
"Each sub-query should focus on a single question, topic, or comparison. "
"Return a JSON array of strings only. "
"Example question: 'Compare treatment A and treatment B for diabetes'\n"
"Example output: [\"treatment A for diabetes\", \"treatment B for diabetes\", \"diabetes treatments comparison\"]\n"
f"User question: {query}"
)
response = client.chat_completion(
messages=[
{
"role": "system",
"content": "You create optimized search queries for a RAG retriever.",
"content": "You decompose complex search queries into list of search sub-queries for a RAG retriever.",
},
{"role": "user", "content": prompt},
],
Expand Down Expand Up @@ -221,10 +259,11 @@ def retrieve(
"""
effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL

# ── Stage 1: Hybrid retrieval with query transformation ───────────────────
all_candidates: List[Dict[str, Any]] = []
# ── Stage 1: Parallel retrieval of sub-queries and RRF merging ───────────
sub_queries = transform_query(query)
sub_query_results: List[List[Dict[str, Any]]] = []

for search_query in transform_query(query):
def retrieve_single_query(search_query: str) -> List[Dict[str, Any]]:
query_vector = embed_query(search_query)

# Vector results (always)
Expand Down Expand Up @@ -257,14 +296,28 @@ def retrieve(
for chunk in merged:
chunk["score"] = chunk.pop("rrf_score")

all_candidates.extend(merged)
return merged
else:
all_candidates.extend(vector_results)
return vector_results

import concurrent.futures
with concurrent.futures.ThreadPoolExecutor(max_workers=len(sub_queries) or 1) as executor:
future_to_query = {executor.submit(retrieve_single_query, sq): sq for sq in sub_queries}
for future in concurrent.futures.as_completed(future_to_query):
try:
results = future.result()
sub_query_results.append(results)
except Exception as e:
sq = future_to_query[future]
logger.error("Failed retrieval for sub-query '%s': %s", sq, e)

if not all_candidates:
if not sub_query_results:
return []

candidates = _merge_candidates(all_candidates)
# Merge all sub-query candidate lists using generalized RRF
candidates = rrf_merge_multiple(sub_query_results, k=settings.RRF_K)
for chunk in candidates:
chunk["score"] = chunk.pop("rrf_score")

# ── Stage 2: Cross-encoder reranking ─────────────────────────────────────
reranker = get_reranker()
Expand Down
Loading