From 85113e9e638a80296e6965df9a3c089d71c33141 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Fri, 3 Apr 2026 01:48:16 -0400 Subject: [PATCH 1/2] Implement Multi-Query Retrieval and fix related bugs --- src/config.py | 2 ++ src/generator.py | 2 +- src/main.py | 76 ++++++++++++++++++++++++++++------------ src/query_enhancement.py | 3 +- 4 files changed, 59 insertions(+), 24 deletions(-) diff --git a/src/config.py b/src/config.py index d296b9c4..a22decb6 100644 --- a/src/config.py +++ b/src/config.py @@ -44,6 +44,8 @@ class RAGConfig: use_hyde: bool = False hyde_max_tokens: int = 300 use_double_prompt: bool = False + use_query_expansion: bool = False + query_expansion_max_tokens: int = 64 # conversational memory enable_history: bool = True diff --git a/src/generator.py b/src/generator.py index 8e8b3181..76aa8385 100644 --- a/src/generator.py +++ b/src/generator.py @@ -110,7 +110,7 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor" _LLM_CACHE = {} -def get_llama_model(model_path: str, n_ctx: int = 4096): +def get_llama_model(model_path: str, n_ctx: int = 8192): if model_path not in _LLM_CACHE: try: _LLM_CACHE[model_path] = Llama(model_path=model_path, diff --git a/src/main.py b/src/main.py index d14b8a23..247cee2b 100644 --- a/src/main.py +++ b/src/main.py @@ -17,7 +17,7 @@ from src.instrumentation.logging import get_logger from src.ranking.ranker import EnsembleRanker from src.preprocessing.chunking import DocumentChunker -from src.query_enhancement import generate_hypothetical_document, contextualize_query +from src.query_enhancement import generate_hypothetical_document, contextualize_query, expand_query_with_keywords from src.retriever import ( filter_retrieved_chunks, BM25Retriever, @@ -47,6 +47,11 @@ def parse_args() -> argparse.Namespace: action="store_true", help="enable double prompting for higher quality answers" ) + parser.add_argument( + "--use_query_expansion", + action="store_true", + help="enable multi-query retrieval (query expansion) for higher accuracy" + ) return parser.parse_args() @@ -131,38 +136,65 @@ def get_answer( elif cfg.use_indexed_chunks: ranked_chunks, topk_idxs = use_indexed_chunks(question, chunks) else: - retrieval_query = question - # print(f"Retrieval query: {retrieval_query}") - if cfg.use_hyde: - retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) + use_expansion = getattr(args, "use_query_expansion", False) or cfg.use_query_expansion + + if use_expansion: + queries_to_run = expand_query_with_keywords(question, cfg.gen_model, max_tokens=cfg.query_expansion_max_tokens) + elif cfg.use_hyde: + queries_to_run = [generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens)] + else: + queries_to_run = [question] pool_n = max(cfg.num_candidates, cfg.top_k + 10) raw_scores: Dict[str, Dict[int, float]] = {} - for retriever in retrievers: - # print(f"Getting scores from retriever: {retriever.name}...") - raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks) - # TODO: Fix retrieval logging. - - # print("Raw scores from retrievers:") - # for retriever_name, score_dict in raw_scores.items(): - # print(f" {retriever_name}: {list(score_dict.values())}") + + # Store original weights so we can adjust dynamically for RRF if multiple queries are used + original_weights = ranker.weights.copy() + + for q_idx, q_variant in enumerate(queries_to_run): + for retriever in retrievers: + scores = retriever.get_scores(q_variant, pool_n, chunks) + + if len(queries_to_run) == 1: + raw_scores[retriever.name] = scores + else: + r_key = f"{retriever.name}_{q_idx}" + raw_scores[r_key] = scores + if r_key not in ranker.weights: + ranker.weights[r_key] = original_weights.get(retriever.name, 0.0) / len(queries_to_run) + + if len(queries_to_run) > 1: + for retriever in retrievers: + ranker.weights[retriever.name] = 0.0 + # Step 2: Ranking ordered, scores = ranker.rank(raw_scores=raw_scores) - # print(f"Ordered candidate indices after ranking: {ordered[:cfg.top_k]}") - # print(f"Corresponding scores: {scores[:cfg.top_k]}") topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered) ranked_chunks = [chunks[i] for i in topk_idxs] - # print(f"Top-{cfg.top_k} chunk indices after filtering: {topk_idxs}") - # print("Len Ranked chunks:", len(ranked_chunks)) - # print("Example ranked chunk content:", ranked_chunks[0] if ranked_chunks else "No chunks retrieved") + # Restore old ranker weights to keep things clean + ranker.weights = original_weights # Capture chunk info if in test mode if is_test_mode: - # Compute individual ranker ranks - faiss_scores = raw_scores.get("faiss", {}) - bm25_scores = raw_scores.get("bm25", {}) - index_scores = raw_scores.get("index_keywords", {}) + from collections import defaultdict + if len(queries_to_run) > 1: + agg_faiss = defaultdict(float) + agg_bm25 = defaultdict(float) + agg_index = defaultdict(float) + for rv, s_dict in raw_scores.items(): + base_rv = rv.rsplit('_', 1)[0] + for cand, sc in s_dict.items(): + if base_rv == "faiss": agg_faiss[cand] += sc + elif base_rv == "bm25": agg_bm25[cand] += sc + elif base_rv == "index_keywords": agg_index[cand] += sc + faiss_scores = dict(agg_faiss) + bm25_scores = dict(agg_bm25) + index_scores = dict(agg_index) + else: + faiss_scores = raw_scores.get("faiss", {}) + bm25_scores = raw_scores.get("bm25", {}) + index_scores = raw_scores.get("index_keywords", {}) faiss_ranked = sorted(faiss_scores.keys(), key=lambda i: faiss_scores[i], reverse=True) bm25_ranked = sorted(bm25_scores.keys(), key=lambda i: bm25_scores[i], reverse=True) diff --git a/src/query_enhancement.py b/src/query_enhancement.py index 4b0705a1..586a79f3 100644 --- a/src/query_enhancement.py +++ b/src/query_enhancement.py @@ -44,10 +44,11 @@ def generate_hypothetical_document( prompt, model_path, max_tokens=max_tokens, + temperature=0.7, **llm_kwargs ) - return hypothetical.strip() + return hypothetical["choices"][0]["text"].strip() def correct_query_grammar( query: str, From 93a84c62eaf19c1d93ab3e4ba5ee44a91e119829 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 9 Apr 2026 21:41:50 -0400 Subject: [PATCH 2/2] update weights for primary query and better enhancement --- config/config.yaml | 3 ++- src/main.py | 10 +++++++--- src/query_enhancement.py | 41 ++++++++++++++++++++++++++++++++++++++++ src/ranking/reranker.py | 2 +- tests/test_api.py | 23 ++++++++++++++++------ 5 files changed, 68 insertions(+), 11 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 262fb5f7..0e8cdb0b 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -4,7 +4,7 @@ num_candidates: 50 ensemble_method: "rrf" ranker_weights: {"faiss":1,"bm25":0,"index_keywords":0} rrf_k: 60 -max_gen_tokens: 400 +max_gen_tokens: 1024 chunk_mode : "recursive_sections" gen_model: "models/qwen2.5-3b-instruct-q8_0.gguf" chunk_size: 2000 @@ -17,3 +17,4 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 +use_query_expansion: false diff --git a/src/main.py b/src/main.py index 247cee2b..d928c701 100644 --- a/src/main.py +++ b/src/main.py @@ -17,7 +17,7 @@ from src.instrumentation.logging import get_logger from src.ranking.ranker import EnsembleRanker from src.preprocessing.chunking import DocumentChunker -from src.query_enhancement import generate_hypothetical_document, contextualize_query, expand_query_with_keywords +from src.query_enhancement import generate_hypothetical_document, contextualize_query, expand_query_with_keywords, expand_query_with_perspectives from src.retriever import ( filter_retrieved_chunks, BM25Retriever, @@ -139,7 +139,7 @@ def get_answer( use_expansion = getattr(args, "use_query_expansion", False) or cfg.use_query_expansion if use_expansion: - queries_to_run = expand_query_with_keywords(question, cfg.gen_model, max_tokens=cfg.query_expansion_max_tokens) + queries_to_run = expand_query_with_perspectives(question, cfg.gen_model, max_tokens=cfg.query_expansion_max_tokens) elif cfg.use_hyde: queries_to_run = [generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens)] else: @@ -161,7 +161,11 @@ def get_answer( r_key = f"{retriever.name}_{q_idx}" raw_scores[r_key] = scores if r_key not in ranker.weights: - ranker.weights[r_key] = original_weights.get(retriever.name, 0.0) / len(queries_to_run) + base_weight = original_weights.get(retriever.name, 0.0) + if q_idx == 0: + ranker.weights[r_key] = base_weight * 0.5 + else: + ranker.weights[r_key] = (base_weight * 0.5) / (len(queries_to_run) - 1) if len(queries_to_run) > 1: for retriever in retrievers: diff --git a/src/query_enhancement.py b/src/query_enhancement.py index 586a79f3..28990859 100644 --- a/src/query_enhancement.py +++ b/src/query_enhancement.py @@ -127,6 +127,47 @@ def expand_query_with_keywords( return query_lines +def expand_query_with_perspectives( + query: str, + model_path: str, + max_tokens: int = 64, + **llm_kwargs +) -> str: + """ + Query Expansion: Generates queries from different technical perspectives. + Provides broader context for retrieval to achieve better performance. + """ + prompt = textwrap.dedent(f"""\ + <|im_start|>system + You are a search optimization expert. + Generate 3 alternative versions of the user's query that approach the topic from different technical perspectives or use different specialized terminology. + This provides broader context for retrieval. Output the alternative queries separated by newlines. Do not provide explanations or numberings. + <|im_end|> + <|im_start|>user + Query: {query} + <|im_end|> + <|im_start|>assistant + """) + + prompt = text_cleaning(prompt) + expansion = run_llama_cpp( + prompt, + model_path, + max_tokens=max_tokens, + temperature=0.5, + **llm_kwargs + ) + + # Combine original query with expansion + query_lines = [query] + query_lines.extend([line.strip() for line in expansion["choices"][0]["text"].split('\n') if line.strip()]) + + # Remove numbering if present + query_lines = [line.split('.', 1)[-1].strip() if '.' in line[:3] else line for line in query_lines] + + return query_lines + + def decompose_complex_query( query: str, model_path: str, diff --git a/src/ranking/reranker.py b/src/ranking/reranker.py index 58e20195..91b97059 100644 --- a/src/ranking/reranker.py +++ b/src/ranking/reranker.py @@ -40,7 +40,7 @@ def rerank_with_cross_encoder(query: str, chunks: List[str], top_n: int) -> List chunk_with_scores = list(zip(chunks, scores)) chunk_with_scores.sort(key=lambda x: x[1], reverse=True) - return chunk_with_scores[:top_n] + return [chunk for chunk, score in chunk_with_scores[:top_n]] # -------------------------- Reranking Router ----------------------------- diff --git a/tests/test_api.py b/tests/test_api.py index f838c290..c8a31912 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -534,16 +534,12 @@ def test_generate_hypothetical_document(self, mock_llm): """generate_hypothetical_document returns string.""" from src.query_enhancement import generate_hypothetical_document - # Note: The function calls .strip() on run_llama_cpp result. - # run_llama_cpp returns a dict, so there's an inconsistency in the source. - # For API testing, mock to return a string (what the function expects). - mock_llm.return_value = "A hypothetical answer about databases." + mock_llm.return_value = {"choices": [{"text": "A hypothetical answer about databases."}]} result = generate_hypothetical_document( "What is a transaction?", model_path="mock_model", - max_tokens=100, - temperature=0.5 + max_tokens=100 ) assert isinstance(result, str) @@ -592,6 +588,21 @@ def test_decompose_complex_query(self, mock_llm): assert isinstance(result, list) + @patch('src.query_enhancement.run_llama_cpp') + def test_expand_query_with_perspectives(self, mock_llm): + """expand_query_with_perspectives returns list of specific perspectives.""" + from src.query_enhancement import expand_query_with_perspectives + + mock_llm.return_value = {"choices": [{"text": "1. Database systems from an architecture perspective\n2. Data storage from a hardware perspective"}]} + + result = expand_query_with_perspectives( + "database architecture", + model_path="mock_model" + ) + + assert isinstance(result, list) + assert len(result) > 0 + # ====================== Load Artifacts Tests ======================