def rrf_merge_multiple(
    ranked_lists: List[List[Dict[str, Any]]],
    k: int = 60,
) -> List[Dict[str, Any]]:
    """Merge multiple ranked lists using Reciprocal Rank Fusion (RRF).

    Every chunk receives ``1 / (k + rank)`` from each list it appears in,
    where ``rank`` is its 1-based position in that list; the contributions
    are summed across lists.

    Args:
        ranked_lists: Ranked result lists, best match first. Each chunk is a
            dict; it is identified by its ``id`` or ``chunk_id`` when one is
            present, otherwise by ``document_id``, ``page`` and the first 200
            characters of ``text``.
        k: RRF smoothing constant added to every rank.

    Returns:
        One shallow copy per distinct chunk, each with an added ``rrf_score``
        key (rounded to 6 decimals), sorted by fused score in descending
        order. Ties keep first-seen order. When a chunk occurs several times,
        the copy is taken from the occurrence with the highest ``score``.
        The input dicts are not modified.
    """
    rrf_scores: Dict[str, float] = {}
    chunk_store: Dict[str, Dict[str, Any]] = {}

    def _key(chunk: Dict[str, Any]) -> str:
        for field in ("id", "chunk_id"):
            value = chunk.get(field)
            # Only None / "" mean "no identifier"; 0 is a legitimate id and
            # must not fall through to the weaker text-based key.
            if value is not None and value != "":
                return str(value)
        text = str(chunk.get("text", ""))
        return "|".join([
            str(chunk.get("document_id", "")),
            str(chunk.get("page", "")),
            text[:200],
        ])

    def _score(chunk: Dict[str, Any]) -> Any:
        # A missing or explicit-None score counts as 0 so the comparison
        # below cannot raise TypeError.
        value = chunk.get("score")
        return 0 if value is None else value

    for ranked in ranked_lists:
        # RRF gives each ranking one vote per chunk: if a chunk is repeated
        # inside the same list, only its first (best) rank contributes.
        seen_in_list = set()
        for rank, chunk in enumerate(ranked, start=1):
            key = _key(chunk)
            if key not in seen_in_list:
                seen_in_list.add(key)
                rrf_scores[key] = rrf_scores.get(key, 0.0) + 1.0 / (k + rank)
            # Keep the highest-scored occurrence as the representative chunk.
            if key not in chunk_store or _score(chunk) > _score(chunk_store[key]):
                chunk_store[key] = chunk

    merged: List[Dict[str, Any]] = []
    # sorted() is stable, so equal fused scores keep first-seen order.
    for key, rrf_score in sorted(rrf_scores.items(), key=lambda t: t[1], reverse=True):
        chunk = chunk_store[key].copy()
        chunk["rrf_score"] = round(rrf_score, 6)
        merged.append(chunk)

    return merged
" + "Example question: 'Compare treatment A and treatment B for diabetes'\n" + "Example output: [\"treatment A for diabetes\", \"treatment B for diabetes\", \"diabetes treatments comparison\"]\n" f"User question: {query}" ) response = client.chat_completion( messages=[ { "role": "system", - "content": "You create optimized search queries for a RAG retriever.", + "content": "You decompose complex search queries into list of search sub-queries for a RAG retriever.", }, {"role": "user", "content": prompt}, ], @@ -221,10 +259,11 @@ def retrieve( """ effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL - # ── Stage 1: Hybrid retrieval with query transformation ─────────────────── - all_candidates: List[Dict[str, Any]] = [] + # ── Stage 1: Parallel retrieval of sub-queries and RRF merging ─────────── + sub_queries = transform_query(query) + sub_query_results: List[List[Dict[str, Any]]] = [] - for search_query in transform_query(query): + def retrieve_single_query(search_query: str) -> List[Dict[str, Any]]: query_vector = embed_query(search_query) # Vector results (always) @@ -257,14 +296,28 @@ def retrieve( for chunk in merged: chunk["score"] = chunk.pop("rrf_score") - all_candidates.extend(merged) + return merged else: - all_candidates.extend(vector_results) + return vector_results + + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor(max_workers=len(sub_queries) or 1) as executor: + future_to_query = {executor.submit(retrieve_single_query, sq): sq for sq in sub_queries} + for future in concurrent.futures.as_completed(future_to_query): + try: + results = future.result() + sub_query_results.append(results) + except Exception as e: + sq = future_to_query[future] + logger.error("Failed retrieval for sub-query '%s': %s", sq, e) - if not all_candidates: + if not sub_query_results: return [] - candidates = _merge_candidates(all_candidates) + # Merge all sub-query candidate lists using generalized RRF + candidates = 
rrf_merge_multiple(sub_query_results, k=settings.RRF_K) + for chunk in candidates: + chunk["score"] = chunk.pop("rrf_score") # ── Stage 2: Cross-encoder reranking ───────────────────────────────────── reranker = get_reranker()