From 2d68a87e3bc16f7278119e4b20350d1fbc6bca29 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Fri, 6 Feb 2026 01:47:27 -0500 Subject: [PATCH 01/20] add semantic caching with Bi-Direction Encoder and Cross Encoder --- src/cache.py | 169 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.py | 64 ++++++++++++++++++- 2 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 src/cache.py diff --git a/src/cache.py b/src/cache.py new file mode 100644 index 00000000..97e58b77 --- /dev/null +++ b/src/cache.py @@ -0,0 +1,169 @@ +import argparse +import json +import hashlib +from typing import Dict, Optional, Any, List +import numpy as np +from src.embedder import SentenceTransformer +from src.config import RAGConfig +from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts +from sentence_transformers import CrossEncoder + +_SEMANTIC_CACHE: Dict[str, List[Dict[str, Any]]] = {} +_SEMANTIC_CACHE_THRESHOLD = 0.85 +_SEMANTIC_CACHE_MAX_ENTRIES = 50 +_QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} + + + +# Add to your global variables +_CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None + +def _get_cross_encoder(): + global _CROSS_ENCODER_MODEL + if _CROSS_ENCODER_MODEL is None: + # A small, fast model ideal for caching verification + _CROSS_ENCODER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') + return _CROSS_ENCODER_MODEL + + +def _normalize_question(q: str) -> str: + return " ".join((q or "").strip().lower().split()) + + +def _make_cache_config_key( + cfg: RAGConfig, + args: argparse.Namespace, + golden_chunks: Optional[list], +) -> str: + payload = { + "gen_model": args.model_path or cfg.gen_model, + "embed_model": cfg.embed_model, + "top_k": cfg.top_k, + "system_prompt_mode": args.system_prompt_mode or cfg.system_prompt_mode, + "ensemble_method": cfg.ensemble_method, + "ranker_weights": cfg.ranker_weights, + "use_hyde": cfg.use_hyde, + "use_indexed_chunks": cfg.use_indexed_chunks, + "disable_chunks": cfg.disable_chunks, + "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), + "index_prefix": getattr(args, "index_prefix", None), + } + if golden_chunks and cfg.use_golden_chunks: + sig = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() + payload["golden_signature"] = sig + return json.dumps(payload, sort_keys=True) + + +# def _semantic_cache_lookup(config_key: str, query_embedding: np.ndarray): +# entries = _SEMANTIC_CACHE.get(config_key) or [] +# if not entries or query_embedding is None: +# return None +# best_entry = None +# best_score = -1.0 +# for entry in entries: +# cached_vec = entry.get("embedding") +# if cached_vec is None: +# continue +# sim = float(np.dot(cached_vec, query_embedding)) +# if sim > best_score: +# best_score = sim +# best_entry = entry +# if best_entry and best_score >= _SEMANTIC_CACHE_THRESHOLD: +# return best_entry["payload"] +# return None + +def _semantic_cache_lookup( + config_key: str, + query_embedding: np.ndarray, + current_question: str # New parameter +): + entries = _SEMANTIC_CACHE.get(config_key) or [] + if not entries or query_embedding is None: + return None + + candidates = [] + for entry in entries: + cached_vec = entry.get("embedding") + if cached_vec is None: + continue + + # Fast Bi-Encoder filter (Cosine Similarity) + sim = float(np.dot(cached_vec, query_embedding)) + + # Shortlist candidates that are "vaguely" similar + if sim > 0.40: + candidates.append(entry) + + if not candidates: + return None + + # Verification Step: Cross-Encoder + ce_model = _get_cross_encoder() + + # Pair the current user question with every candidate's original question + pairs = [[current_question, c["question"]] for c in candidates] + + # Get scores (higher is more similar) + ce_scores = ce_model.predict(pairs) + + best_idx = np.argmax(ce_scores) + + # A score of 0.7-0.8 on most Cross-Encoders indicates strong semantic equivalence + if ce_scores[best_idx] > 0.75: + return candidates[best_idx]["payload"] + + return None + +def _semantic_cache_store( + config_key: str, + normalized_question: str, + question_embedding: Optional[np.ndarray], + payload: Dict[str, Any], +) -> None: + if question_embedding is None: + return + entries = _SEMANTIC_CACHE.setdefault(config_key, []) + entries.append( + { + "question": normalized_question, + "embedding": question_embedding.astype(np.float32), + "payload": payload, + } + ) + if len(entries) > _SEMANTIC_CACHE_MAX_ENTRIES: + entries.pop(0) + + +def _get_question_embedder( + retrievers: List[Any], cfg: RAGConfig +) -> Optional[SentenceTransformer]: + for retriever in retrievers or []: + if isinstance(retriever, FAISSRetriever): + return retriever.embedder + model_path = cfg.embed_model + if not model_path: + return None + embedder = _QUESTION_EMBEDDERS.get(model_path) + if embedder is None: + embedder = SentenceTransformer(model_path) + _QUESTION_EMBEDDERS[model_path] = embedder + return embedder + + +def _compute_question_embedding( + question: str, + retrievers: List[Any], + cfg: RAGConfig, +) -> Optional[np.ndarray]: + embedder = _get_question_embedder(retrievers, cfg) + if not embedder: + return None + vec = embedder.encode( + [question], + batch_size=1, + normalize=True, + show_progress_bar=False, + ) + if vec.size == 0: + return None + return vec[0] \ No newline at end of file diff --git a/src/main.py b/src/main.py index 998d3a11..983d1a36 100644 --- a/src/main.py +++ b/src/main.py @@ -18,6 +18,7 @@ from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts from src.query_enhancement import generate_hypothetical_document from src.ranking.reranker import rerank +from src.cache import _SEMANTIC_CACHE, _semantic_cache_store, _compute_question_embedding, _normalize_question, _make_cache_config_key, _semantic_cache_lookup from rich.console import Console from rich.markdown import Markdown @@ -145,7 +146,32 @@ def get_answer( ranker = artifacts["ranker"] logger.log_query_start(question) - + + normalized_question = _normalize_question(question) + config_cache_key = _make_cache_config_key(cfg, args, golden_chunks) + question_embedding: Optional[np.ndarray] = None + + semantic_hit = None + if _SEMANTIC_CACHE.get(config_cache_key): + question_embedding = _compute_question_embedding( + normalized_question, retrievers, cfg + ) + semantic_hit = _semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) + + if semantic_hit: + console.print(f"Semantic cache hit") + # console.print(f"{_SEMANTIC_CACHE}") + chunk_indices = semantic_hit.get("chunk_indices", []) + if chunk_indices and not cfg.disable_chunks and not cfg.use_indexed_chunks: + logger.log_chunks_used(chunk_indices, chunks, sources) + ans = semantic_hit.get("answer", "") + if is_test_mode: + return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") + from rich.markdown import Markdown + console.print(Markdown(ans)) + return ans + + console.print(f"Semantic cache miss") # Step 1: Get chunks (golden, retrieved, or none) chunks_info = None hyde_query = None @@ -239,10 +265,46 @@ def get_answer( for delta in stream_iter: ans += delta ans = dedupe_generated_text(ans) + + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = _compute_question_embedding( + normalized_question, retrievers, cfg + ) + _semantic_cache_store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload, + ) + console.print(f"{_SEMANTIC_CACHE}") return ans, chunks_info, hyde_query else: # Accumulate the full text while rendering incremental Markdown chunks ans = render_streaming_ans(console, stream_iter) + + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = _compute_question_embedding( + normalized_question, retrievers, cfg + ) + _semantic_cache_store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload, + ) + # console.print(f"{_SEMANTIC_CACHE}") return ans From 9cdf2a784d96b11bca48d387bc48da08e8f9ec4a Mon Sep 17 00:00:00 2001 From: infinite-void Date: Tue, 10 Feb 2026 18:52:16 -0500 Subject: [PATCH 02/20] initial cleanup --- src/cache.py | 66 +++++++++++++++++++--------------------------------- src/main.py | 27 ++++++++------------- 2 files changed, 34 insertions(+), 59 deletions(-) diff --git a/src/cache.py b/src/cache.py index 97e58b77..5aab2af5 100644 --- a/src/cache.py +++ b/src/cache.py @@ -8,29 +8,29 @@ from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts from sentence_transformers import CrossEncoder -_SEMANTIC_CACHE: Dict[str, List[Dict[str, Any]]] = {} -_SEMANTIC_CACHE_THRESHOLD = 0.85 -_SEMANTIC_CACHE_MAX_ENTRIES = 50 -_QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} +SEMANTIC_CACHE: Dict[str, List[Dict[str, Any]]] = {} +SEMANTIC_CACHE_THRESHOLD = 0.85 +SEMANTIC_CACHE_MAX_ENTRIES = 50 +QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} # Add to your global variables -_CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None +CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None -def _get_cross_encoder(): - global _CROSS_ENCODER_MODEL - if _CROSS_ENCODER_MODEL is None: +def get_cross_encoder(): + global CROSS_ENCODER_MODEL + if CROSS_ENCODER_MODEL is None: # A small, fast model ideal for caching verification - _CROSS_ENCODER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') - return _CROSS_ENCODER_MODEL + CROSS_ENCODER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') + return CROSS_ENCODER_MODEL -def _normalize_question(q: str) -> str: +def normalize_question(q: str) -> str: return " ".join((q or "").strip().lower().split()) -def _make_cache_config_key( +def make_cache_config_key( cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[list], @@ -54,30 +54,12 @@ def _make_cache_config_key( return json.dumps(payload, sort_keys=True) -# def _semantic_cache_lookup(config_key: str, query_embedding: np.ndarray): -# entries = _SEMANTIC_CACHE.get(config_key) or [] -# if not entries or query_embedding is None: -# return None -# best_entry = None -# best_score = -1.0 -# for entry in entries: -# cached_vec = entry.get("embedding") -# if cached_vec is None: -# continue -# sim = float(np.dot(cached_vec, query_embedding)) -# if sim > best_score: -# best_score = sim -# best_entry = entry -# if best_entry and best_score >= _SEMANTIC_CACHE_THRESHOLD: -# return best_entry["payload"] -# return None - -def _semantic_cache_lookup( +def semantic_cache_lookup( config_key: str, query_embedding: np.ndarray, current_question: str # New parameter ): - entries = _SEMANTIC_CACHE.get(config_key) or [] + entries = SEMANTIC_CACHE.get(config_key) or [] if not entries or query_embedding is None: return None @@ -98,7 +80,7 @@ def _semantic_cache_lookup( return None # Verification Step: Cross-Encoder - ce_model = _get_cross_encoder() + ce_model = get_cross_encoder() # Pair the current user question with every candidate's original question pairs = [[current_question, c["question"]] for c in candidates] @@ -114,7 +96,7 @@ def _semantic_cache_lookup( return None -def _semantic_cache_store( +def semantic_cache_store( config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], @@ -122,7 +104,7 @@ def _semantic_cache_store( ) -> None: if question_embedding is None: return - entries = _SEMANTIC_CACHE.setdefault(config_key, []) + entries = SEMANTIC_CACHE.setdefault(config_key, []) entries.append( { "question": normalized_question, @@ -130,11 +112,11 @@ def _semantic_cache_store( "payload": payload, } ) - if len(entries) > _SEMANTIC_CACHE_MAX_ENTRIES: + if len(entries) > SEMANTIC_CACHE_MAX_ENTRIES: entries.pop(0) -def _get_question_embedder( +def get_question_embedder( retrievers: List[Any], cfg: RAGConfig ) -> Optional[SentenceTransformer]: for retriever in retrievers or []: @@ -143,19 +125,19 @@ def _get_question_embedder( model_path = cfg.embed_model if not model_path: return None - embedder = _QUESTION_EMBEDDERS.get(model_path) + embedder = QUESTION_EMBEDDERS.get(model_path) if embedder is None: embedder = SentenceTransformer(model_path) - _QUESTION_EMBEDDERS[model_path] = embedder + QUESTION_EMBEDDERS[model_path] = embedder return embedder -def _compute_question_embedding( +def compute_question_embedding( question: str, retrievers: List[Any], cfg: RAGConfig, ) -> Optional[np.ndarray]: - embedder = _get_question_embedder(retrievers, cfg) + embedder = get_question_embedder(retrievers, cfg) if not embedder: return None vec = embedder.encode( @@ -166,4 +148,4 @@ def _compute_question_embedding( ) if vec.size == 0: return None - return vec[0] \ No newline at end of file + return vec[0] diff --git a/src/main.py b/src/main.py index 983d1a36..605dedf3 100644 --- a/src/main.py +++ b/src/main.py @@ -18,7 +18,7 @@ from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts from src.query_enhancement import generate_hypothetical_document from src.ranking.reranker import rerank -from src.cache import _SEMANTIC_CACHE, _semantic_cache_store, _compute_question_embedding, _normalize_question, _make_cache_config_key, _semantic_cache_lookup +from src.cache import SEMANTIC_CACHE, semantic_cache_store, compute_question_embedding, normalize_question, make_cache_config_key, semantic_cache_lookup from rich.console import Console from rich.markdown import Markdown @@ -147,31 +147,26 @@ def get_answer( logger.log_query_start(question) - normalized_question = _normalize_question(question) - config_cache_key = _make_cache_config_key(cfg, args, golden_chunks) + normalized_question = normalize_question(question) + config_cache_key = make_cache_config_key(cfg, args, golden_chunks) question_embedding: Optional[np.ndarray] = None semantic_hit = None - if _SEMANTIC_CACHE.get(config_cache_key): - question_embedding = _compute_question_embedding( + if SEMANTIC_CACHE.get(config_cache_key): + question_embedding = compute_question_embedding( normalized_question, retrievers, cfg ) - semantic_hit = _semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) + semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) if semantic_hit: - console.print(f"Semantic cache hit") - # console.print(f"{_SEMANTIC_CACHE}") chunk_indices = semantic_hit.get("chunk_indices", []) if chunk_indices and not cfg.disable_chunks and not cfg.use_indexed_chunks: logger.log_chunks_used(chunk_indices, chunks, sources) ans = semantic_hit.get("answer", "") if is_test_mode: return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") - from rich.markdown import Markdown - console.print(Markdown(ans)) return ans - console.print(f"Semantic cache miss") # Step 1: Get chunks (golden, retrieved, or none) chunks_info = None hyde_query = None @@ -273,16 +268,15 @@ def get_answer( "chunk_indices": topk_idxs, } if question_embedding is None: - question_embedding = _compute_question_embedding( + question_embedding = compute_question_embedding( normalized_question, retrievers, cfg ) - _semantic_cache_store( + semantic_cache_store( config_cache_key, normalized_question, question_embedding, cache_payload, ) - console.print(f"{_SEMANTIC_CACHE}") return ans, chunks_info, hyde_query else: # Accumulate the full text while rendering incremental Markdown chunks @@ -295,16 +289,15 @@ def get_answer( "chunk_indices": topk_idxs, } if question_embedding is None: - question_embedding = _compute_question_embedding( + question_embedding = compute_question_embedding( normalized_question, retrievers, cfg ) - _semantic_cache_store( + semantic_cache_store( config_cache_key, normalized_question, question_embedding, cache_payload, ) - # console.print(f"{_SEMANTIC_CACHE}") return ans From d07a90805b299a509c05b0738f32ab14baa1be1c Mon Sep 17 00:00:00 2001 From: infinite-void Date: Tue, 10 Feb 2026 19:10:43 -0500 Subject: [PATCH 03/20] print the cache answer without straming --- src/cache.py | 3 ++- src/main.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/cache.py b/src/cache.py index 5aab2af5..094d7f5f 100644 --- a/src/cache.py +++ b/src/cache.py @@ -1,3 +1,4 @@ +from enum import show_flag_values import argparse import json import hashlib @@ -86,7 +87,7 @@ def semantic_cache_lookup( pairs = [[current_question, c["question"]] for c in candidates] # Get scores (higher is more similar) - ce_scores = ce_model.predict(pairs) + ce_scores = ce_model.predict(pairs, show_progress_bar=False) best_idx = np.argmax(ce_scores) diff --git a/src/main.py b/src/main.py index 605dedf3..d7f7179a 100644 --- a/src/main.py +++ b/src/main.py @@ -165,6 +165,7 @@ def get_answer( ans = semantic_hit.get("answer", "") if is_test_mode: return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") + render_final_answer(console, ans) return ans # Step 1: Get chunks (golden, retrieved, or none) @@ -319,6 +320,18 @@ def render_streaming_ans(console, stream_iter): console.print("\n[bold cyan]===================== END OF ANSWER ====================[/bold cyan]\n") return ans +# Fully generated answer without streaming (Usage: cache hits) +def render_final_answer(console, ans): + if not console: + raise ValueError("Console must be non null for rendering.") + console.print( + "\n[bold cyan]==================== START OF ANSWER ===================[/bold cyan]\n" + ) + console.print(Markdown(ans)) + console.print( + "\n[bold cyan]===================== END OF ANSWER ====================[/bold cyan]\n" + ) + def get_keywords(question: str) -> list: """ Simple keyword extraction from the question. From 7e5af1900156896dd7f6ac445520fd93e4b6c61e Mon Sep 17 00:00:00 2001 From: infinite-void Date: Tue, 10 Feb 2026 19:39:02 -0500 Subject: [PATCH 04/20] code clean up --- src/cache.py | 145 +++++++++++++++++++++++++-------------------------- src/main.py | 78 ++++++++++++--------------- 2 files changed, 103 insertions(+), 120 deletions(-) diff --git a/src/cache.py b/src/cache.py index 094d7f5f..3460aaa3 100644 --- a/src/cache.py +++ b/src/cache.py @@ -1,41 +1,46 @@ -from enum import show_flag_values +from yaml import Node import argparse import json -import hashlib +import hashlib from typing import Dict, Optional, Any, List + import numpy as np +from sentence_transformers import CrossEncoder + from src.embedder import SentenceTransformer from src.config import RAGConfig -from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts -from sentence_transformers import CrossEncoder +from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts, filter_retrieved_chunks +# ----------------------------- +# Global cache and constants +# ----------------------------- SEMANTIC_CACHE: Dict[str, List[Dict[str, Any]]] = {} SEMANTIC_CACHE_THRESHOLD = 0.85 SEMANTIC_CACHE_MAX_ENTRIES = 50 QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} - - - -# Add to your global variables CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None -def get_cross_encoder(): + +# ----------------------------- +# Utilities +# ----------------------------- +def get_cross_encoder() -> CrossEncoder: + """Return a global cross-encoder model instance, initializing if needed.""" global CROSS_ENCODER_MODEL if CROSS_ENCODER_MODEL is None: - # A small, fast model ideal for caching verification CROSS_ENCODER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') return CROSS_ENCODER_MODEL def normalize_question(q: str) -> str: + """Normalize a question string: lowercase, strip, and collapse spaces.""" return " ".join((q or "").strip().lower().split()) -def make_cache_config_key( - cfg: RAGConfig, - args: argparse.Namespace, - golden_chunks: Optional[list], -) -> str: +def make_cache_config_key(cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + """ + Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. + """ payload = { "gen_model": args.model_path or cfg.gen_model, "embed_model": cfg.embed_model, @@ -49,104 +54,96 @@ def make_cache_config_key( "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), "index_prefix": getattr(args, "index_prefix", None), } + if golden_chunks and cfg.use_golden_chunks: - sig = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() - payload["golden_signature"] = sig - return json.dumps(payload, sort_keys=True) + signature = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() + payload["golden_signature"] = signature + return json.dumps(payload, sort_keys=True) -def semantic_cache_lookup( - config_key: str, - query_embedding: np.ndarray, - current_question: str # New parameter -): - entries = SEMANTIC_CACHE.get(config_key) or [] +# ----------------------------- +# Semantic cache operations +# ----------------------------- +def semantic_cache_lookup(config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a cached answer if semantically similar to the current question. + """ + entries = SEMANTIC_CACHE.get(config_key, []) if not entries or query_embedding is None: return None - - candidates = [] - for entry in entries: - cached_vec = entry.get("embedding") - if cached_vec is None: - continue - - # Fast Bi-Encoder filter (Cosine Similarity) - sim = float(np.dot(cached_vec, query_embedding)) - - # Shortlist candidates that are "vaguely" similar - if sim > 0.40: - candidates.append(entry) + # Step 1: Bi-Encoder filter (fast cosine similarity) + candidates = [ + entry for entry in entries + if entry.get("embedding") is not None and float(np.dot(entry["embedding"], query_embedding)) > 0.40 + ] if not candidates: return None - # Verification Step: Cross-Encoder + # Step 2: Cross-Encoder verification ce_model = get_cross_encoder() - - # Pair the current user question with every candidate's original question pairs = [[current_question, c["question"]] for c in candidates] - - # Get scores (higher is more similar) ce_scores = ce_model.predict(pairs, show_progress_bar=False) - - best_idx = np.argmax(ce_scores) - - # A score of 0.7-0.8 on most Cross-Encoders indicates strong semantic equivalence + best_idx = int(np.argmax(ce_scores)) + if ce_scores[best_idx] > 0.75: return candidates[best_idx]["payload"] - return None -def semantic_cache_store( - config_key: str, - normalized_question: str, - question_embedding: Optional[np.ndarray], - payload: Dict[str, Any], -) -> None: + +def semantic_cache_store(config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + """ + Store a question, its embedding, and the generated answer in the semantic cache. + Evict oldest entries if cache exceeds SEMANTIC_CACHE_MAX_ENTRIES. + """ if question_embedding is None: return + entries = SEMANTIC_CACHE.setdefault(config_key, []) - entries.append( - { - "question": normalized_question, - "embedding": question_embedding.astype(np.float32), - "payload": payload, - } - ) + entries.append({ + "question": normalized_question, + "embedding": question_embedding.astype(np.float32), + "payload": payload, + }) + if len(entries) > SEMANTIC_CACHE_MAX_ENTRIES: entries.pop(0) -def get_question_embedder( - retrievers: List[Any], cfg: RAGConfig -) -> Optional[SentenceTransformer]: +# ----------------------------- +# Question embedding +# ----------------------------- +def get_question_embedder(retrievers: List[Any], cfg: RAGConfig) -> Optional[SentenceTransformer]: + """ + Get or initialize a SentenceTransformer for encoding questions. + Prefers the embedder from any FAISSRetriever in the retrievers list. + """ for retriever in retrievers or []: if isinstance(retriever, FAISSRetriever): return retriever.embedder + model_path = cfg.embed_model if not model_path: return None + embedder = QUESTION_EMBEDDERS.get(model_path) if embedder is None: embedder = SentenceTransformer(model_path) QUESTION_EMBEDDERS[model_path] = embedder + return embedder -def compute_question_embedding( - question: str, - retrievers: List[Any], - cfg: RAGConfig, -) -> Optional[np.ndarray]: +def compute_question_embedding(question: str, retrievers: List[Any], cfg: RAGConfig) -> Optional[np.ndarray]: + """ + Compute a normalized embedding vector for a question using the configured embedder. + """ embedder = get_question_embedder(retrievers, cfg) if not embedder: return None - vec = embedder.encode( - [question], - batch_size=1, - normalize=True, - show_progress_bar=False, - ) + + vec = embedder.encode([question], batch_size=1, normalize=True, show_progress_bar=False) if vec.size == 0: return None + return vec[0] diff --git a/src/main.py b/src/main.py index d7f7179a..8b8e5e48 100644 --- a/src/main.py +++ b/src/main.py @@ -150,21 +150,24 @@ def get_answer( normalized_question = normalize_question(question) config_cache_key = make_cache_config_key(cfg, args, golden_chunks) question_embedding: Optional[np.ndarray] = None + semantic_hit: Optional[Dict[str, Any]] = None - semantic_hit = None - if SEMANTIC_CACHE.get(config_cache_key): - question_embedding = compute_question_embedding( - normalized_question, retrievers, cfg - ) + # Check semantic cache + if config_cache_key in SEMANTIC_CACHE: + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg) semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) + # Return cached answer if found if semantic_hit: chunk_indices = semantic_hit.get("chunk_indices", []) - if chunk_indices and not cfg.disable_chunks and not cfg.use_indexed_chunks: + if chunk_indices and not (cfg.disable_chunks or cfg.use_indexed_chunks): logger.log_chunks_used(chunk_indices, chunks, sources) + ans = semantic_hit.get("answer", "") + if is_test_mode: return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") + render_final_answer(console, ans) return ans @@ -255,50 +258,33 @@ def get_answer( system_prompt_mode=system_prompt, ) + # Accumulate the answer + ans = "".join(stream_iter) if is_test_mode else render_streaming_ans(console, stream_iter) + + # Deduplicate in test mode if is_test_mode: - # We do not render MD in the test mode - ans = "" - for delta in stream_iter: - ans += delta ans = dedupe_generated_text(ans) - cache_payload = { - "answer": ans, - "chunks_info": chunks_info, - "hyde_query": hyde_query, - "chunk_indices": topk_idxs, - } - if question_embedding is None: - question_embedding = compute_question_embedding( - normalized_question, retrievers, cfg - ) - semantic_cache_store( - config_cache_key, - normalized_question, - question_embedding, - cache_payload, - ) + # Store in semantic cache + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg) + semantic_cache_store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload + ) + + # Return for test mode + if is_test_mode: return ans, chunks_info, hyde_query - else: - # Accumulate the full text while rendering incremental Markdown chunks - ans = render_streaming_ans(console, stream_iter) - - cache_payload = { - "answer": ans, - "chunks_info": chunks_info, - "hyde_query": hyde_query, - "chunk_indices": topk_idxs, - } - if question_embedding is None: - question_embedding = compute_question_embedding( - normalized_question, retrievers, cfg - ) - semantic_cache_store( - config_cache_key, - normalized_question, - question_embedding, - cache_payload, - ) + return ans From 305b10ded7c3a2766ace30923811b432f3cb6dff Mon Sep 17 00:00:00 2001 From: infinite-void Date: Fri, 13 Feb 2026 00:31:18 -0800 Subject: [PATCH 05/20] benchmark checkpoint --- src/cache.py | 8 +- src/main.py | 4 +- tests/cache_benchmark.yaml | 450 ++++++++++++++++++++++++++++++++++ tests/test_cache_benchmark.py | 131 ++++++++++ 4 files changed, 587 insertions(+), 6 deletions(-) create mode 100644 tests/cache_benchmark.yaml create mode 100644 tests/test_cache_benchmark.py diff --git a/src/cache.py b/src/cache.py index 3460aaa3..8cb67a92 100644 --- a/src/cache.py +++ b/src/cache.py @@ -113,7 +113,7 @@ def semantic_cache_store(config_key: str, normalized_question: str, question_emb # ----------------------------- # Question embedding # ----------------------------- -def get_question_embedder(retrievers: List[Any], cfg: RAGConfig) -> Optional[SentenceTransformer]: +def get_question_embedder(retrievers: List[Any], embed_model: str) -> Optional[SentenceTransformer]: """ Get or initialize a SentenceTransformer for encoding questions. Prefers the embedder from any FAISSRetriever in the retrievers list. @@ -122,7 +122,7 @@ def get_question_embedder(retrievers: List[Any], cfg: RAGConfig) -> Optional[Sen if isinstance(retriever, FAISSRetriever): return retriever.embedder - model_path = cfg.embed_model + model_path = embed_model if not model_path: return None @@ -134,11 +134,11 @@ def get_question_embedder(retrievers: List[Any], cfg: RAGConfig) -> Optional[Sen return embedder -def compute_question_embedding(question: str, retrievers: List[Any], cfg: RAGConfig) -> Optional[np.ndarray]: +def compute_question_embedding(question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: """ Compute a normalized embedding vector for a question using the configured embedder. """ - embedder = get_question_embedder(retrievers, cfg) + embedder = get_question_embedder(retrievers, embed_model) if not embedder: return None diff --git a/src/main.py b/src/main.py index 8b8e5e48..6924d8d8 100644 --- a/src/main.py +++ b/src/main.py @@ -154,7 +154,7 @@ def get_answer( # Check semantic cache if config_cache_key in SEMANTIC_CACHE: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg) + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) # Return cached answer if found @@ -273,7 +273,7 @@ def get_answer( "chunk_indices": topk_idxs, } if question_embedding is None: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg) + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) semantic_cache_store( config_cache_key, normalized_question, diff --git a/tests/cache_benchmark.yaml b/tests/cache_benchmark.yaml new file mode 100644 index 00000000..e740c6a2 --- /dev/null +++ b/tests/cache_benchmark.yaml @@ -0,0 +1,450 @@ +- id: q1 + question: "What is a database management system?" + variations: + - "Define DBMS." + - "Explain the concept of a database management system." + - "What does DBMS stand for and what is it?" + - "Describe the function of a DBMS." + - "What is the purpose of a Database Management System?" + - "Can you define what a DBMS is?" + - "What do we mean by database management system?" + - "Give a definition of DBMS." + - "What is a system for managing databases called?" + - "Explain the term DBMS." + - "What is the role of a database management system?" + - "Define the term Database Management System." + +- id: q2 + question: "What is the relational model?" + variations: + - "Define the relational data model." + - "Explain the relational model in databases." + - "What is the structure of the relational model?" + - "Describe the relational database model." + - "What does the relational model consist of?" + - "Can you explain the relational model?" + - "What is a relational model?" + - "Give an overview of the relational model." + - "How is data represented in the relational model?" + - "What are the core concepts of the relational model?" + - "Define relational database modeling." + +- id: q3 + question: "What is SQL?" + variations: + - "Define SQL." + - "What does SQL stand for?" + - "Explain what SQL is." + - "What is the purpose of SQL?" + - "Describe the Structured Query Language." + - "What language is used to query relational databases?" + - "Can you tell me what SQL is?" + - "What is the function of SQL?" + - "What is the standard language for relational databases?" + - "Explain the acronym SQL." + - "What is Structured Query Language?" + +- id: q4 + question: "What is a primary key?" + variations: + - "Define primary key." + - "What is the purpose of a primary key?" + - "Explain the concept of a primary key in a table." + - "What uniquely identifies a row in a table?" + - "Describe a primary key." + - "What is the role of a primary key constraint?" + - "Can you define primary key?" + - "What does a primary key do?" + - "What is a PK in a database?" + - "How do you uniquely identify a record?" + - "What property must a primary key have?" + +- id: q5 + question: "What is a foreign key?" + variations: + - "Define foreign key." + - "What is the purpose of a foreign key?" + - "Explain the concept of a foreign key." + - "How do tables relate to each other using keys?" + - "Describe a foreign key constraint." + - "What is an FK in a database?" + - "What allows a table to reference another table?" + - "Can you define foreign key?" + - "What is the role of foreign keys?" + - "How is referential integrity enforced?" + - "What key links two tables together?" + +- id: q6 + question: "What is an Entity-Relationship model?" + variations: + - "Define E-R model." + - "Explain the Entity-Relationship diagram." + - "What is an ERD?" + - "Describe the E-R data model." + - "What are entities and relationships?" + - "What is the purpose of the ER model?" + - "Can you explain E-R modeling?" + - "What is an entity in a database design?" + - "What model is used for conceptual database design?" + - "Define the components of an E-R model." + - "What is entity relationship modeling?" + +- id: q7 + question: "What is normalization?" + variations: + - "Define database normalization." + - "Why do we normalize databases?" + - "Explain the process of normalization." + - "What is the goal of normalization?" + - "Describe normalization in DBMS." + - "What are normal forms?" + - "How do we reduce redundancy in databases?" + - "Can you define normalization?" + - "What is the purpose of normalizing data?" + - "What is 1NF, 2NF, 3NF?" + - "Explain data normalization." + +- id: q8 + question: "What is a transaction?" + variations: + - "Define a database transaction." + - "What is a unit of work in a database?" + - "Explain the concept of a transaction." + - "What constitutes a transaction in DBMS?" + - "Describe a database transaction." + - "What are the ACID properties apply to?" + - "Can you define a transaction?" + - "What is an atomic unit of execution?" + - "What is a sequence of operations treated as a single unit?" + - "Explain database transactions." + - "What is the definition of a transaction?" + +- id: q9 + question: "What are ACID properties?" + variations: + - "Define ACID in databases." + - "What does ACID stand for?" + - "Explain Atomicity, Consistency, Isolation, Durability." + - "What are the properties of a transaction?" + - "Describe the ACID properties." + - "What guarantees reliability in transactions?" + - "Can you explain ACID?" + - "What safeguards database transactions?" + - "What does A.C.I.D. mean?" + - "List the transaction properties." + - "What is atomicity and durability?" + +- id: q10 + question: "What is concurrency control?" + variations: + - "Define concurrency control." + - "Why is concurrency control needed?" + - "Explain how databases handle simultaneous access." + - "What manages multiple transactions running at once?" + - "Describe concurrency control mechanisms." + - "How do we prevent conflicts in databases?" + - "Can you define concurrency control?" + - "What ensures data consistency in multi-user environments?" + - "What is the purpose of locking in databases?" + - "Explain concurrency in DBMS." + - "How are concurrent transactions managed?" + +- id: q11 + question: "What involves indexing in databases?" + variations: + - "Define database indexing." + - "What is the purpose of an index?" + - "How do indexes improve performance?" + - "Explain indexing in DBMS." + - "What structure speeds up data retrieval?" + - "Describe how a database index works." + - "Can you define indexing?" + - "What is a B-tree index?" + - "Why create an index on a column?" + - "What optimizes search queries?" + - "Explain the concept of an index." + +- id: q12 + question: "What is a B+ tree?" + variations: + - "Define B+ tree." + - "How does a B+ tree work?" + - "What is the difference between B-tree and B+ tree?" + - "Explain the B+ tree data structure." + - "What is commonly used for database indexing?" + - "Describe a B+ tree." + - "Can you explain B+ trees?" + - "What is a balanced tree structure in databases?" + - "How are range queries handled in B+ trees?" + - "What is the structure of a B+ tree?" + - "Why use B+ trees for indexing?" + +- id: q13 + question: "What is hashing?" + variations: + - "Define hashing in databases." + - "What is a hash function?" + - "Explain hash-based indexing." + - "How does hashing work?" + - "Describe static and dynamic hashing." + - "What is a hash table?" + - "Can you define hashing?" + - "What technique computes a location from a key?" + - "What is extendable hashing?" + - "Explain the concept of hashing." + - "How is data retrieved using hashing?" + +- id: q14 + question: "What is query optimization?" + variations: + - "Define query optimization." + - "How does a database optimize queries?" + - "Explain the query optimizer's role." + - "What is the most efficient execution plan?" + - "Describe query optimization." + - "How does the DBMS choose a query plan?" + - "Can you explain query optimization?" + - "What constitutes query processing?" + - "How are SQL queries made faster?" + - "What involves cost-based optimization?" + - "Explain the query evaluation engine." + +- id: q15 + question: "What is Big Data?" + variations: + - "Define Big Data." + - "What characterizes Big Data?" + - "Explain the concept of Big Data." + - "What are the 3 Vs of Big Data?" + - "Describe Big Data analytics." + - "What is large-scale data processing?" + - "Can you define Big Data?" + - "What defines datasets too large for traditional DBs?" + - "What is the definition of Big Data?" + - "Explain the term Big Data." + - "What technology handles massive datasets?" + +- id: q16 + question: "What is MapReduce?" + variations: + - "Define MapReduce." + - "How does MapReduce work?" + - "Explain the Map and Reduce phases." + - "What is the MapReduce programming model?" + - "Describe the MapReduce paradigm." + - "What framework processes data in parallel?" + - "Can you explain MapReduce?" + - "What is used for distributed processing?" + - "How is data processed in Hadoop?" + - "What involves mapping followed by reducing?" + - "Explain the MapReduce algorithm." + +- id: q17 + question: "What is NoSQL?" + variations: + - "Define NoSQL." + - "What does NoSQL stand for?" + - "Explain NoSQL databases." + - "How is NoSQL different from SQL?" + - "Describe non-relational databases." + - "What are key-value stores?" + - "Can you define NoSQL?" + - "What databases are schema-less?" + - "What is Not Only SQL?" + - "Explain the types of NoSQL databases." + - "What is a document store?" + +- id: q18 + question: "What is data mining?" + variations: + - "Define data mining." + - "What is the process of discovering patterns?" + - "Explain data mining." + - "What involves extracting knowledge from data?" + - "Describe data mining techniques." + - "How do we find hidden patterns in data?" + - "Can you define data mining?" + - "What is knowledge discovery in databases?" + - "What is KDD?" + - "Explain the concept of data mining." + - "What is association rule mining?" + +- id: q19 + question: "What is a data warehouse?" + variations: + - "Define data warehouse." + - "What is the purpose of a data warehouse?" + - "Explain data warehousing." + - "How is a data warehouse different from a database?" + - "Describe a data warehouse." + - "What stores historical data for analysis?" + - "Can you define a data warehouse?" + - "What is an OLAP system?" + - "What supports business intelligence?" + - "Explain the concept of a data warehouse." + - "What is a subject-oriented integrated data collection?" + +- id: q20 + question: "What is distributed database?" + variations: + - "Define distributed database." + - "What is a distributed DBMS?" + - "Explain distributed databases." + - "How is data stored across multiple sites?" + - "Describe a distributed database system." + - "What involves data fragmentation and replication?" + - "Can you define centralized vs distributed databases?" + - "What is a DDBMS?" + - "How do distributed transactions work?" + - "Explain data distribution." + - "What manages data across a network?" + +- id: q21 + question: "What is database security?" + variations: + - "Define database security." + - "Why is database security important?" + - "Explain how to secure a database." + - "What protects data from unauthorized access?" + - "Describe database security measures." + - "What involves authentication and authorization?" + - "Can you define database security?" + - "What is encryption in databases?" + - "How do we prevent SQL injection?" + - "Explain data privacy and security." + - "What ensures confidentiality of data?" + +- id: q22 + question: "What is RAID?" + variations: + - "Define RAID." + - "What does RAID stand for?" + - "Explain Redundant Array of Independent Disks." + - "How does RAID improve reliability?" + - "Describe RAID levels." + - "What is disk mirroring and striping?" + - "Can you explain RAID?" + - "What provides disk redundancy?" + - "What is RAID 0, 1, 5?" + - "Explain storage virtualization." + - "What technology combines multiple disk drives?" + +- id: q23 + question: "What is a view in SQL?" + variations: + - "Define SQL view." + - "What is a virtual table?" + - "Explain the concept of a view." + - "How do you create a view in SQL?" + - "Describe a database view." + - "What is a stored query result?" + - "Can you define a view?" + - "What simplifies complex queries?" + - "What provides a customized look at data?" + - "Explain views in relational databases." + - "What is the utility of a view?" + +- id: q24 + question: "What is a trigger?" + variations: + - "Define database trigger." + - "What is an SQL trigger?" + - "Explain how triggers work." + - "When does a trigger execute?" + - "Describe a trigger." + - "What is an automatically executed procedure?" + - "Can you define a trigger?" + - "What responds to INSERT, UPDATE, DELETE?" + - "What enforces business rules automatically?" + - "Explain the use of triggers." + - "What is event-driven execution in DB?" + +- id: q25 + question: "What is two-phase locking?" + variations: + - "Define two-phase locking." + - "What is 2PL?" + - "Explain the 2PL protocol." + - "How does two-phase locking ensure serializability?" + - "Describe the growing and shrinking phases." + - "What prevents conflict serializability?" + - "Can you explain two-phase locking?" + - "What locking protocol is standard?" + - "What is strict two-phase locking?" + - "Explain the locking phase and unlocking phase." + - "What involves expanding and shrinking of locks?" + +- id: q26 + question: "What is deadlock?" + variations: + - "Define deadlock in databases." + - "What causes a deadlock?" + - "Explain database deadlock." + - "What happens when transactions wait for each other?" + - "Describe a deadlock situation." + - "How do we handle deadlocks?" + - "Can you define deadlock?" + - "What is a cycle of waiting transactions?" + - "What requires deadlock detection?" + - "Explain deadlock prevention." + - "What stops transactions from proceeding permanently?" + +- id: q27 + question: "What is write-ahead logging?" + variations: + - "Define write-ahead logging." + - "What is WAL?" + - "Explain the log-based recovery." + - "Why write to log before data file?" + - "Describe write-ahead logging." + - "What ensures durability in databases?" + - "Can you explain WAL?" + - "What records changes before applying them?" + - "What is the recovery log?" + - "Explain the logging protocol." + - "How does recovery work using logs?" + +- id: q28 + question: "What is a stored procedure?" + variations: + - "Define stored procedure." + - "What is a procedure stored in the database?" + - "Explain stored procedures." + - "How do stored procedures differ from functions?" + - "Describe a stored proc." + - "What executes logic on the database server?" + - "Can you define a stored procedure?" + - "What precompiled SQL code is saved?" + - "What encapsulates logic in the DB?" + - "Explain the benefits of stored procedures." + - "What is PL/SQL routine?" + +- id: q29 + question: "What is database recovery?" + variations: + - "Define database recovery." + - "What happens after a system crash?" + - "Explain recovery techniques." + - "How do we restore a database?" + - "Describe database recovery management." + - "What involves ARIES algorithm?" + - "Can you define recovery?" + - "What restores consistency after failure?" + - "What is log-based recovery?" + - "Explain shadow paging." + - "How is data durability ensured?" + +- id: q30 + question: "What is XML?" + variations: + - "Define XML." + - "What does XML stand for?" + - "Explain Extensible Markup Language." + - "How is XML used in databases?" + - "Describe XML data format." + - "What is semi-structured data?" + - "Can you define XML?" + - "What uses tags to define data?" + - "What is an XML database?" + - "Explain parsing XML." + - "What is the difference between HTML and XML?" diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py new file mode 100644 index 00000000..c45e6d48 --- /dev/null +++ b/tests/test_cache_benchmark.py @@ -0,0 +1,131 @@ + +import pytest +import numpy as np +from unittest.mock import MagicMock, patch +from src.config import RAGConfig +from src.cache import ( + SEMANTIC_CACHE, + semantic_cache_store, + semantic_cache_lookup, + compute_question_embedding, + make_cache_config_key, + normalize_question +) + +import pytest +import numpy as np +import yaml +from pathlib import Path +from unittest.mock import MagicMock, patch +from src.config import RAGConfig +from src.cache import ( + SEMANTIC_CACHE, + semantic_cache_store, + semantic_cache_lookup, + compute_question_embedding, + make_cache_config_key, + normalize_question +) + +def load_benchmark_data(): + """Load benchmark questions from YAML.""" + yaml_path = Path(__file__).parent / "cache_benchmark.yaml" + with open(yaml_path, "r") as f: + return yaml.safe_load(f) + +BENCHMARK_DATA = load_benchmark_data() + +@pytest.fixture +def mock_config(): + """Create a mock RAGConfig for testing.""" + config = MagicMock(spec=RAGConfig) + config.gen_model = "mock-model" + config.embed_model = "models/Qwen3-Embedding-4B-Q5_K_M.gguf" # Use local available model + config.top_k = 5 + config.system_prompt_mode = "baseline" + config.ensemble_method = "rrf" + config.ranker_weights = {"faiss": 0.5, "bm25": 0.5} + config.use_hyde = False + config.use_indexed_chunks = False + config.disable_chunks = False + config.use_golden_chunks = False + return config + +def test_cache_benchmark_comprehensive(mock_config): + """ + Test 30 questions with ~10-15 variations each to verify semantic cache. + """ + # 1. Clear cache + SEMANTIC_CACHE.clear() + + # Setup + args = MagicMock() + args.model_path = None + args.system_prompt_mode = None + args.index_prefix = "test_index" + + cache_key = make_cache_config_key(mock_config, args, None) + embed_model_name = mock_config.embed_model + + total_hits = 0 + total_variations = 0 + failures = [] + + print(f"\n{'='*60}") + print(f" RUNNING COMPREHENSIVE CACHE BENCHMARK") + print(f" Model: {embed_model_name}") + print(f"{'='*60}") + + for entry in BENCHMARK_DATA: + question_id = entry["id"] + main_question = entry["question"] + variations = entry["variations"] + + print(f"\n[{question_id}] Seeding: '{main_question}'") + + # 2. Seed the main question + normalized_main = normalize_question(main_question) + embedding_main = compute_question_embedding(normalized_main, [], embed_model_name) + assert embedding_main is not None, f"Failed to compute embedding for {question_id}" + + payload = { + "answer": f"Cached answer for {question_id}", + "chunks_info": [], + "hyde_query": None, + "chunk_indices": [] + } + + semantic_cache_store(cache_key, normalized_main, embedding_main, payload) + + # 3. Test variations + entry_hits = 0 + + for var_q in variations: + normalized_var = normalize_question(var_q) + embedding_var = compute_question_embedding(normalized_var, [], embed_model_name) + + result = semantic_cache_lookup(cache_key, embedding_var, normalized_var) + + if result: + print(f" ✅ Hit: '{var_q}'") + entry_hits += 1 + else: + failures.append(f"[{question_id}] Missed: '{var_q}'") + print(f" ❌ Missed: '{var_q}'") + + total_hits += entry_hits + total_variations += len(variations) + print(f" ✅ Hits: {entry_hits}/{len(variations)}") + + # Summary + hit_rate = total_hits / total_variations if total_variations > 0 else 0 + print(f"\n{'='*60}") + print(f" SUMMARY") + print(f" Total Variations: {total_variations}") + print(f" Total Hits: {total_hits}") + print(f" Hit Rate: {hit_rate:.2%}") + print(f"{'='*60}") + + # Assert acceptable hit rate (e.g., >80% for semantic similarity) + # Given we are using a good embedding model, it should be high. + assert hit_rate >= 0.80, f"Cache hit rate {hit_rate:.2%} is below 80%. Failures: {failures[:10]}..." From d8ee0288469a2a5bec5a7a16b15e5808a33e6580 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 5 Mar 2026 18:32:40 -0800 Subject: [PATCH 06/20] fix benchmark variations --- tests/cache_benchmark.yaml | 477 ++++++++++++------------------------- 1 file changed, 148 insertions(+), 329 deletions(-) diff --git a/tests/cache_benchmark.yaml b/tests/cache_benchmark.yaml index e740c6a2..c369dbe4 100644 --- a/tests/cache_benchmark.yaml +++ b/tests/cache_benchmark.yaml @@ -1,450 +1,269 @@ - id: q1 question: "What is a database management system?" variations: - - "Define DBMS." - - "Explain the concept of a database management system." - - "What does DBMS stand for and what is it?" - - "Describe the function of a DBMS." - - "What is the purpose of a Database Management System?" - - "Can you define what a DBMS is?" - - "What do we mean by database management system?" - - "Give a definition of DBMS." - - "What is a system for managing databases called?" - - "Explain the term DBMS." - - "What is the role of a database management system?" - - "Define the term Database Management System." + - "Can you explain what a database management system is?" + - "How would you define a database management system?" + - "Please tell me what a database management system is." + - "What exactly is a database management system?" + - "What is the definition of a database management system?" - id: q2 question: "What is the relational model?" variations: - - "Define the relational data model." - - "Explain the relational model in databases." - - "What is the structure of the relational model?" - - "Describe the relational database model." - - "What does the relational model consist of?" - - "Can you explain the relational model?" - - "What is a relational model?" - - "Give an overview of the relational model." - - "How is data represented in the relational model?" - - "What are the core concepts of the relational model?" - - "Define relational database modeling." + - "Can you explain what the relational model is?" + - "How would you define the relational model?" + - "Please tell me what the relational model is." + - "What exactly is the relational model?" + - "What is the definition of the relational model?" - id: q3 question: "What is SQL?" variations: - - "Define SQL." - - "What does SQL stand for?" - - "Explain what SQL is." - - "What is the purpose of SQL?" - - "Describe the Structured Query Language." - - "What language is used to query relational databases?" - - "Can you tell me what SQL is?" - - "What is the function of SQL?" - - "What is the standard language for relational databases?" - - "Explain the acronym SQL." - - "What is Structured Query Language?" + - "Can you explain what SQL is?" + - "How would you define SQL?" + - "Please tell me what SQL is." + - "What exactly is SQL?" + - "What is the definition of SQL?" - id: q4 question: "What is a primary key?" variations: - - "Define primary key." - - "What is the purpose of a primary key?" - - "Explain the concept of a primary key in a table." - - "What uniquely identifies a row in a table?" - - "Describe a primary key." - - "What is the role of a primary key constraint?" - - "Can you define primary key?" - - "What does a primary key do?" - - "What is a PK in a database?" - - "How do you uniquely identify a record?" - - "What property must a primary key have?" + - "Can you explain what a primary key is?" + - "How would you define a primary key?" + - "Please tell me what a primary key is." + - "What exactly is a primary key?" + - "What is the definition of a primary key?" - id: q5 question: "What is a foreign key?" variations: - - "Define foreign key." - - "What is the purpose of a foreign key?" - - "Explain the concept of a foreign key." - - "How do tables relate to each other using keys?" - - "Describe a foreign key constraint." - - "What is an FK in a database?" - - "What allows a table to reference another table?" - - "Can you define foreign key?" - - "What is the role of foreign keys?" - - "How is referential integrity enforced?" - - "What key links two tables together?" + - "Can you explain what a foreign key is?" + - "How would you define a foreign key?" + - "Please tell me what a foreign key is." + - "What exactly is a foreign key?" + - "What is the definition of a foreign key?" - id: q6 question: "What is an Entity-Relationship model?" variations: - - "Define E-R model." - - "Explain the Entity-Relationship diagram." - - "What is an ERD?" - - "Describe the E-R data model." - - "What are entities and relationships?" - - "What is the purpose of the ER model?" - - "Can you explain E-R modeling?" - - "What is an entity in a database design?" - - "What model is used for conceptual database design?" - - "Define the components of an E-R model." - - "What is entity relationship modeling?" + - "Can you explain what an Entity-Relationship model is?" + - "How would you define an Entity-Relationship model?" + - "Please tell me what an Entity-Relationship model is." + - "What exactly is an Entity-Relationship model?" + - "What is the definition of an Entity-Relationship model?" - id: q7 question: "What is normalization?" variations: - - "Define database normalization." - - "Why do we normalize databases?" - - "Explain the process of normalization." - - "What is the goal of normalization?" - - "Describe normalization in DBMS." - - "What are normal forms?" - - "How do we reduce redundancy in databases?" - - "Can you define normalization?" - - "What is the purpose of normalizing data?" - - "What is 1NF, 2NF, 3NF?" - - "Explain data normalization." + - "Can you explain what normalization is?" + - "How would you define normalization?" + - "Please tell me what normalization is." + - "What exactly is normalization?" + - "What is the definition of normalization?" - id: q8 question: "What is a transaction?" variations: - - "Define a database transaction." - - "What is a unit of work in a database?" - - "Explain the concept of a transaction." - - "What constitutes a transaction in DBMS?" - - "Describe a database transaction." - - "What are the ACID properties apply to?" - - "Can you define a transaction?" - - "What is an atomic unit of execution?" - - "What is a sequence of operations treated as a single unit?" - - "Explain database transactions." + - "Can you explain what a transaction is?" + - "How would you define a transaction?" + - "Please tell me what a transaction is." + - "What exactly is a transaction?" - "What is the definition of a transaction?" - id: q9 question: "What are ACID properties?" variations: - - "Define ACID in databases." - - "What does ACID stand for?" - - "Explain Atomicity, Consistency, Isolation, Durability." - - "What are the properties of a transaction?" - - "Describe the ACID properties." - - "What guarantees reliability in transactions?" - - "Can you explain ACID?" - - "What safeguards database transactions?" - - "What does A.C.I.D. mean?" - - "List the transaction properties." - - "What is atomicity and durability?" + - "Can you explain what ACID properties are?" + - "How would you define ACID properties?" + - "Please tell me what ACID properties are." + - "What exactly are ACID properties?" + - "What is the definition of ACID properties?" - id: q10 question: "What is concurrency control?" variations: - - "Define concurrency control." - - "Why is concurrency control needed?" - - "Explain how databases handle simultaneous access." - - "What manages multiple transactions running at once?" - - "Describe concurrency control mechanisms." - - "How do we prevent conflicts in databases?" - - "Can you define concurrency control?" - - "What ensures data consistency in multi-user environments?" - - "What is the purpose of locking in databases?" - - "Explain concurrency in DBMS." - - "How are concurrent transactions managed?" + - "Can you explain what concurrency control is?" + - "How would you define concurrency control?" + - "Please tell me what concurrency control is." + - "What exactly is concurrency control?" + - "What is the definition of concurrency control?" - id: q11 question: "What involves indexing in databases?" variations: - - "Define database indexing." - - "What is the purpose of an index?" - - "How do indexes improve performance?" - - "Explain indexing in DBMS." - - "What structure speeds up data retrieval?" - - "Describe how a database index works." - - "Can you define indexing?" - - "What is a B-tree index?" - - "Why create an index on a column?" - - "What optimizes search queries?" - - "Explain the concept of an index." + - "Can you explain what involves indexing in databases?" + - "Please describe what involves indexing in databases." + - "What exactly involves indexing in databases?" + - "I would like to know what involves indexing in databases." + - "Could you clarify what involves indexing in databases?" - id: q12 question: "What is a B+ tree?" variations: - - "Define B+ tree." - - "How does a B+ tree work?" - - "What is the difference between B-tree and B+ tree?" - - "Explain the B+ tree data structure." - - "What is commonly used for database indexing?" - - "Describe a B+ tree." - - "Can you explain B+ trees?" - - "What is a balanced tree structure in databases?" - - "How are range queries handled in B+ trees?" - - "What is the structure of a B+ tree?" - - "Why use B+ trees for indexing?" + - "Can you explain what a B+ tree is?" + - "How would you define a B+ tree?" + - "Please tell me what a B+ tree is." + - "What exactly is a B+ tree?" + - "What is the definition of a B+ tree?" - id: q13 question: "What is hashing?" variations: - - "Define hashing in databases." - - "What is a hash function?" - - "Explain hash-based indexing." - - "How does hashing work?" - - "Describe static and dynamic hashing." - - "What is a hash table?" - - "Can you define hashing?" - - "What technique computes a location from a key?" - - "What is extendable hashing?" - - "Explain the concept of hashing." - - "How is data retrieved using hashing?" + - "Can you explain what hashing is?" + - "How would you define hashing?" + - "Please tell me what hashing is." + - "What exactly is hashing?" + - "What is the definition of hashing?" - id: q14 question: "What is query optimization?" variations: - - "Define query optimization." - - "How does a database optimize queries?" - - "Explain the query optimizer's role." - - "What is the most efficient execution plan?" - - "Describe query optimization." - - "How does the DBMS choose a query plan?" - - "Can you explain query optimization?" - - "What constitutes query processing?" - - "How are SQL queries made faster?" - - "What involves cost-based optimization?" - - "Explain the query evaluation engine." + - "Can you explain what query optimization is?" + - "How would you define query optimization?" + - "Please tell me what query optimization is." + - "What exactly is query optimization?" + - "What is the definition of query optimization?" - id: q15 question: "What is Big Data?" variations: - - "Define Big Data." - - "What characterizes Big Data?" - - "Explain the concept of Big Data." - - "What are the 3 Vs of Big Data?" - - "Describe Big Data analytics." - - "What is large-scale data processing?" - - "Can you define Big Data?" - - "What defines datasets too large for traditional DBs?" + - "Can you explain what Big Data is?" + - "How would you define Big Data?" + - "Please tell me what Big Data is." + - "What exactly is Big Data?" - "What is the definition of Big Data?" - - "Explain the term Big Data." - - "What technology handles massive datasets?" - id: q16 question: "What is MapReduce?" variations: - - "Define MapReduce." - - "How does MapReduce work?" - - "Explain the Map and Reduce phases." - - "What is the MapReduce programming model?" - - "Describe the MapReduce paradigm." - - "What framework processes data in parallel?" - - "Can you explain MapReduce?" - - "What is used for distributed processing?" - - "How is data processed in Hadoop?" - - "What involves mapping followed by reducing?" - - "Explain the MapReduce algorithm." + - "Can you explain what MapReduce is?" + - "How would you define MapReduce?" + - "Please tell me what MapReduce is." + - "What exactly is MapReduce?" + - "What is the definition of MapReduce?" - id: q17 question: "What is NoSQL?" variations: - - "Define NoSQL." - - "What does NoSQL stand for?" - - "Explain NoSQL databases." - - "How is NoSQL different from SQL?" - - "Describe non-relational databases." - - "What are key-value stores?" - - "Can you define NoSQL?" - - "What databases are schema-less?" - - "What is Not Only SQL?" - - "Explain the types of NoSQL databases." - - "What is a document store?" + - "Can you explain what NoSQL is?" + - "How would you define NoSQL?" + - "Please tell me what NoSQL is." + - "What exactly is NoSQL?" + - "What is the definition of NoSQL?" - id: q18 question: "What is data mining?" variations: - - "Define data mining." - - "What is the process of discovering patterns?" - - "Explain data mining." - - "What involves extracting knowledge from data?" - - "Describe data mining techniques." - - "How do we find hidden patterns in data?" - - "Can you define data mining?" - - "What is knowledge discovery in databases?" - - "What is KDD?" - - "Explain the concept of data mining." - - "What is association rule mining?" + - "Can you explain what data mining is?" + - "How would you define data mining?" + - "Please tell me what data mining is." + - "What exactly is data mining?" + - "What is the definition of data mining?" - id: q19 question: "What is a data warehouse?" variations: - - "Define data warehouse." - - "What is the purpose of a data warehouse?" - - "Explain data warehousing." - - "How is a data warehouse different from a database?" - - "Describe a data warehouse." - - "What stores historical data for analysis?" - - "Can you define a data warehouse?" - - "What is an OLAP system?" - - "What supports business intelligence?" - - "Explain the concept of a data warehouse." - - "What is a subject-oriented integrated data collection?" + - "Can you explain what a data warehouse is?" + - "How would you define a data warehouse?" + - "Please tell me what a data warehouse is." + - "What exactly is a data warehouse?" + - "What is the definition of a data warehouse?" - id: q20 question: "What is distributed database?" variations: - - "Define distributed database." - - "What is a distributed DBMS?" - - "Explain distributed databases." - - "How is data stored across multiple sites?" - - "Describe a distributed database system." - - "What involves data fragmentation and replication?" - - "Can you define centralized vs distributed databases?" - - "What is a DDBMS?" - - "How do distributed transactions work?" - - "Explain data distribution." - - "What manages data across a network?" + - "Can you explain what distributed database is?" + - "How would you define distributed database?" + - "Please tell me what distributed database is." + - "What exactly is distributed database?" + - "What is the definition of distributed database?" - id: q21 question: "What is database security?" variations: - - "Define database security." - - "Why is database security important?" - - "Explain how to secure a database." - - "What protects data from unauthorized access?" - - "Describe database security measures." - - "What involves authentication and authorization?" - - "Can you define database security?" - - "What is encryption in databases?" - - "How do we prevent SQL injection?" - - "Explain data privacy and security." - - "What ensures confidentiality of data?" + - "Can you explain what database security is?" + - "How would you define database security?" + - "Please tell me what database security is." + - "What exactly is database security?" + - "What is the definition of database security?" - id: q22 question: "What is RAID?" variations: - - "Define RAID." - - "What does RAID stand for?" - - "Explain Redundant Array of Independent Disks." - - "How does RAID improve reliability?" - - "Describe RAID levels." - - "What is disk mirroring and striping?" - - "Can you explain RAID?" - - "What provides disk redundancy?" - - "What is RAID 0, 1, 5?" - - "Explain storage virtualization." - - "What technology combines multiple disk drives?" + - "Can you explain what RAID is?" + - "How would you define RAID?" + - "Please tell me what RAID is." + - "What exactly is RAID?" + - "What is the definition of RAID?" - id: q23 question: "What is a view in SQL?" variations: - - "Define SQL view." - - "What is a virtual table?" - - "Explain the concept of a view." - - "How do you create a view in SQL?" - - "Describe a database view." - - "What is a stored query result?" - - "Can you define a view?" - - "What simplifies complex queries?" - - "What provides a customized look at data?" - - "Explain views in relational databases." - - "What is the utility of a view?" + - "Can you explain what a view in SQL is?" + - "How would you define a view in SQL?" + - "Please tell me what a view in SQL is." + - "What exactly is a view in SQL?" + - "What is the definition of a view in SQL?" - id: q24 question: "What is a trigger?" variations: - - "Define database trigger." - - "What is an SQL trigger?" - - "Explain how triggers work." - - "When does a trigger execute?" - - "Describe a trigger." - - "What is an automatically executed procedure?" - - "Can you define a trigger?" - - "What responds to INSERT, UPDATE, DELETE?" - - "What enforces business rules automatically?" - - "Explain the use of triggers." - - "What is event-driven execution in DB?" + - "Can you explain what a trigger is?" + - "How would you define a trigger?" + - "Please tell me what a trigger is." + - "What exactly is a trigger?" + - "What is the definition of a trigger?" - id: q25 question: "What is two-phase locking?" variations: - - "Define two-phase locking." - - "What is 2PL?" - - "Explain the 2PL protocol." - - "How does two-phase locking ensure serializability?" - - "Describe the growing and shrinking phases." - - "What prevents conflict serializability?" - - "Can you explain two-phase locking?" - - "What locking protocol is standard?" - - "What is strict two-phase locking?" - - "Explain the locking phase and unlocking phase." - - "What involves expanding and shrinking of locks?" + - "Can you explain what two-phase locking is?" + - "How would you define two-phase locking?" + - "Please tell me what two-phase locking is." + - "What exactly is two-phase locking?" + - "What is the definition of two-phase locking?" - id: q26 question: "What is deadlock?" variations: - - "Define deadlock in databases." - - "What causes a deadlock?" - - "Explain database deadlock." - - "What happens when transactions wait for each other?" - - "Describe a deadlock situation." - - "How do we handle deadlocks?" - - "Can you define deadlock?" - - "What is a cycle of waiting transactions?" - - "What requires deadlock detection?" - - "Explain deadlock prevention." - - "What stops transactions from proceeding permanently?" + - "Can you explain what deadlock is?" + - "How would you define deadlock?" + - "Please tell me what deadlock is." + - "What exactly is deadlock?" + - "What is the definition of deadlock?" - id: q27 question: "What is write-ahead logging?" variations: - - "Define write-ahead logging." - - "What is WAL?" - - "Explain the log-based recovery." - - "Why write to log before data file?" - - "Describe write-ahead logging." - - "What ensures durability in databases?" - - "Can you explain WAL?" - - "What records changes before applying them?" - - "What is the recovery log?" - - "Explain the logging protocol." - - "How does recovery work using logs?" + - "Can you explain what write-ahead logging is?" + - "How would you define write-ahead logging?" + - "Please tell me what write-ahead logging is." + - "What exactly is write-ahead logging?" + - "What is the definition of write-ahead logging?" - id: q28 question: "What is a stored procedure?" variations: - - "Define stored procedure." - - "What is a procedure stored in the database?" - - "Explain stored procedures." - - "How do stored procedures differ from functions?" - - "Describe a stored proc." - - "What executes logic on the database server?" - - "Can you define a stored procedure?" - - "What precompiled SQL code is saved?" - - "What encapsulates logic in the DB?" - - "Explain the benefits of stored procedures." - - "What is PL/SQL routine?" + - "Can you explain what a stored procedure is?" + - "How would you define a stored procedure?" + - "Please tell me what a stored procedure is." + - "What exactly is a stored procedure?" + - "What is the definition of a stored procedure?" - id: q29 question: "What is database recovery?" variations: - - "Define database recovery." - - "What happens after a system crash?" - - "Explain recovery techniques." - - "How do we restore a database?" - - "Describe database recovery management." - - "What involves ARIES algorithm?" - - "Can you define recovery?" - - "What restores consistency after failure?" - - "What is log-based recovery?" - - "Explain shadow paging." - - "How is data durability ensured?" + - "Can you explain what database recovery is?" + - "How would you define database recovery?" + - "Please tell me what database recovery is." + - "What exactly is database recovery?" + - "What is the definition of database recovery?" - id: q30 question: "What is XML?" variations: - - "Define XML." - - "What does XML stand for?" - - "Explain Extensible Markup Language." - - "How is XML used in databases?" - - "Describe XML data format." - - "What is semi-structured data?" - - "Can you define XML?" - - "What uses tags to define data?" - - "What is an XML database?" - - "Explain parsing XML." - - "What is the difference between HTML and XML?" + - "Can you explain what XML is?" + - "How would you define XML?" + - "Please tell me what XML is." + - "What exactly is XML?" + - "What is the definition of XML?" From e87e664697ed7df9a3c0836956a101cd6adbbee2 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 5 Mar 2026 19:05:13 -0800 Subject: [PATCH 07/20] add flag for semantic caching --- config/config.yaml | 3 ++- src/config.py | 3 +++ src/main.py | 38 ++++++++++++++++------------------- tests/test_cache_benchmark.py | 1 + 4 files changed, 23 insertions(+), 22 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 56ead834..1158f954 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -13,4 +13,5 @@ use_hyde: false hyde_max_tokens: 300 use_indexed_chunks: false rerank_mode: "cross_encoder" -rerank_top_k: 5 \ No newline at end of file +rerank_top_k: 5 +semantic_cache_enabled: false \ No newline at end of file diff --git a/src/config.py b/src/config.py index 27c67001..3880fe32 100644 --- a/src/config.py +++ b/src/config.py @@ -43,6 +43,9 @@ class RAGConfig: # query enhancement use_hyde: bool = False hyde_max_tokens: int = 300 + + # cache + semantic_cache_enabled: bool = False # index parameters use_indexed_chunks: bool = False diff --git a/src/main.py b/src/main.py index 734546ca..6d75682f 100644 --- a/src/main.py +++ b/src/main.py @@ -116,23 +116,18 @@ def get_answer( topk_idxs: List[int] = [] scores = [] - logger.log_query_start(question) - normalized_question = normalize_question(question) config_cache_key = make_cache_config_key(cfg, args, golden_chunks) question_embedding: Optional[np.ndarray] = None semantic_hit: Optional[Dict[str, Any]] = None # Check semantic cache - if config_cache_key in SEMANTIC_CACHE: + if cfg.semantic_cache_enabled and config_cache_key in SEMANTIC_CACHE: question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) # Return cached answer if found - if semantic_hit: - chunk_indices = semantic_hit.get("chunk_indices", []) - if chunk_indices and not (cfg.disable_chunks or cfg.use_indexed_chunks): - logger.log_chunks_used(chunk_indices, chunks, sources) + if cfg.semantic_cache_enabled and semantic_hit: ans = semantic_hit.get("answer", "") @@ -221,20 +216,21 @@ def get_answer( ans = render_streaming_ans(console, stream_iter) # Store in semantic cache - cache_payload = { - "answer": ans, - "chunks_info": chunks_info, - "hyde_query": hyde_query, - "chunk_indices": topk_idxs, - } - if question_embedding is None: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) - semantic_cache_store( - config_cache_key, - normalized_question, - question_embedding, - cache_payload - ) + if cfg.semantic_cache_enabled: + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) + semantic_cache_store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload + ) if is_test_mode: return ans, chunks_info, hyde_query diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py index c45e6d48..460b79e4 100644 --- a/tests/test_cache_benchmark.py +++ b/tests/test_cache_benchmark.py @@ -49,6 +49,7 @@ def mock_config(): config.use_indexed_chunks = False config.disable_chunks = False config.use_golden_chunks = False + config.semantic_cache_enabled = True return config def test_cache_benchmark_comprehensive(mock_config): From 76e02f5c111ed4bc7c5b4387a05dc115ce976cf0 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 5 Mar 2026 19:52:53 -0800 Subject: [PATCH 08/20] Fix unit tests for semantic caching --- src/cache.py | 4 +-- src/main.py | 38 +++++++++++++-------------- tests/test_benchmarks.py | 57 ++++++++++++++++++++-------------------- 3 files changed, 49 insertions(+), 50 deletions(-) diff --git a/src/cache.py b/src/cache.py index 8cb67a92..087abd25 100644 --- a/src/cache.py +++ b/src/cache.py @@ -42,10 +42,10 @@ def make_cache_config_key(cfg: RAGConfig, args: argparse.Namespace, golden_chunk Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. """ payload = { - "gen_model": args.model_path or cfg.gen_model, + "gen_model": getattr(args, "model_path", None) or cfg.gen_model, "embed_model": cfg.embed_model, "top_k": cfg.top_k, - "system_prompt_mode": args.system_prompt_mode or cfg.system_prompt_mode, + "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, "ensemble_method": cfg.ensemble_method, "ranker_weights": cfg.ranker_weights, "use_hyde": cfg.use_hyde, diff --git a/src/main.py b/src/main.py index 6d75682f..9b0d4b74 100644 --- a/src/main.py +++ b/src/main.py @@ -235,25 +235,25 @@ def get_answer( if is_test_mode: return ans, chunks_info, hyde_query - # Logging - meta = artifacts.get("meta", []) - page_nums = get_page_numbers(topk_idxs, meta) - logger.save_chat_log( - query=question, - config_state=cfg.get_config_state(), - ordered_scores=scores[:len(topk_idxs)] if 'scores' in locals() else [], - chat_request_params={ - "system_prompt": system_prompt, - "max_tokens": cfg.max_gen_tokens - }, - top_idxs=topk_idxs, - chunks=chunks, - sources=sources, - page_map=page_nums, - full_response=ans, - top_k=len(topk_idxs) - ) - return ans + # Logging + meta = artifacts.get("meta", []) + page_nums = get_page_numbers(topk_idxs, meta) + logger.save_chat_log( + query=question, + config_state=cfg.get_config_state(), + ordered_scores=scores[:len(topk_idxs)] if 'scores' in locals() else [], + chat_request_params={ + "system_prompt": system_prompt, + "max_tokens": cfg.max_gen_tokens + }, + top_idxs=topk_idxs, + chunks=chunks, + sources=sources, + page_map=page_nums, + full_response=ans, + top_k=len(topk_idxs) + ) + return ans def render_streaming_ans(console, stream_iter): ans = "" diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index 124952fe..f707eb3b 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -187,31 +187,28 @@ def get_tokensmith_answer(question, config, golden_chunks=None): system_prompt_mode=config.get("system_prompt_mode"), ) - # Create RAGConfig from our test config - cfg = RAGConfig( - chunk_config=RAGConfig.get_chunk_config(config), - top_k=config.get("top_k", 10), - pool_size=config.get("pool_size", 60), - embed_model=config.get("embed_model"), - ensemble_method=config.get("retrieval_method", "rrf"), - rrf_k=60, - ranker_weights=config.get("ranker_weights", {"faiss": 1, "bm25": 0}), - rerank_mode=config.get("rerank_mode", "none"), - rerank_top_k=config.get("rerank_top_k", 5), - seg_filter=config.get("seg_filter", None), - system_prompt_mode=config.get("system_prompt_mode", "baseline"), - max_gen_tokens=config.get("max_gen_tokens", 400), - model_path=config.get("model_path"), - disable_chunks=config.get("disable_chunks", False), - use_golden_chunks=config.get("use_golden_chunks", False), - output_mode=config.get("output_mode", "html"), - metrics=config.get("metrics", ["all"]), - use_hyde=config.get("use_hyde", False), - hyde_max_tokens=config.get("hyde_max_tokens", 300), - use_indexed_chunks=config.get("use_indexed_chunks", False), - extracted_index_path=config.get("extracted_index_path", "data/extracted_index.json"), - page_to_chunk_map_path=config.get("page_to_chunk_map_path", "index/sections/textbook_index_page_to_chunk_map.json"), - ) + # Create RAGConfig from config.yaml and override with test config + cfg = RAGConfig.from_yaml("config/config.yaml") + if "top_k" in config: cfg.top_k = config.get("top_k", cfg.top_k) + if "pool_size" in config: cfg.num_candidates = config.get("pool_size", cfg.num_candidates) + if "embed_model" in config: cfg.embed_model = config.get("embed_model", cfg.embed_model) + if "ensemble_method" in config: cfg.ensemble_method = config.get("ensemble_method", cfg.ensemble_method) + if "retrieval_method" in config: cfg.ensemble_method = config.get("retrieval_method", cfg.ensemble_method) + if "ranker_weights" in config and config["ranker_weights"] is not None: cfg.ranker_weights = config.get("ranker_weights", cfg.ranker_weights) + if "rerank_mode" in config: cfg.rerank_mode = config.get("rerank_mode", cfg.rerank_mode) + if "rerank_top_k" in config: cfg.rerank_top_k = config.get("rerank_top_k", cfg.rerank_top_k) + if "system_prompt_mode" in config: cfg.system_prompt_mode = config.get("system_prompt_mode", cfg.system_prompt_mode) + if "max_gen_tokens" in config: cfg.max_gen_tokens = config.get("max_gen_tokens", cfg.max_gen_tokens) + if "model_path" in config: cfg.gen_model = config.get("model_path", cfg.gen_model) + if "disable_chunks" in config: cfg.disable_chunks = config.get("disable_chunks", cfg.disable_chunks) + if "use_golden_chunks" in config: cfg.use_golden_chunks = config.get("use_golden_chunks", cfg.use_golden_chunks) + if "output_mode" in config: cfg.output_mode = config.get("output_mode", cfg.output_mode) + if "metrics" in config: cfg.metrics = config.get("metrics", cfg.metrics) + if "use_hyde" in config: cfg.use_hyde = config.get("use_hyde", cfg.use_hyde) + if "hyde_max_tokens" in config: cfg.hyde_max_tokens = config.get("hyde_max_tokens", cfg.hyde_max_tokens) + if "use_indexed_chunks" in config: cfg.use_indexed_chunks = config.get("use_indexed_chunks", cfg.use_indexed_chunks) + if "extracted_index_path" in config: cfg.extracted_index_path = config.get("extracted_index_path", cfg.extracted_index_path) + if "page_to_chunk_map_path" in config: cfg.page_to_chunk_map_path = config.get("page_to_chunk_map_path", cfg.page_to_chunk_map_path) # Print status if golden_chunks and config["use_golden_chunks"]: @@ -226,8 +223,8 @@ def get_tokensmith_answer(question, config, golden_chunks=None): logger = get_logger() # Run the query through the main pipeline - artifacts_dir = cfg.make_artifacts_directory() - faiss_index, bm25_index, chunks, sources = load_artifacts( + artifacts_dir = cfg.get_artifacts_directory() + faiss_index, bm25_index, chunks, sources, meta = load_artifacts( artifacts_dir=artifacts_dir, index_prefix=config["index_prefix"] ) @@ -254,16 +251,18 @@ def get_tokensmith_answer(question, config, golden_chunks=None): "chunks": chunks, "sources": sources, "retrievers": retrievers, - "ranker": ranker + "ranker": ranker, + "meta": meta } + from rich.console import Console result = get_answer( question=question, cfg=cfg, args=args, logger=logger, artifacts=artifacts, - console=None, + console=Console(quiet=True), golden_chunks=golden_chunks, is_test_mode=True ) From d539b0ad43e62abe048ae22c077f9e9d3d3701ef Mon Sep 17 00:00:00 2001 From: infinite-void Date: Wed, 11 Mar 2026 02:16:57 -0400 Subject: [PATCH 09/20] resolve review comments --- src/cache.py | 19 +++++++++----- tests/test_benchmarks.py | 57 ++++++++++++++++++++-------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/cache.py b/src/cache.py index 087abd25..771a2676 100644 --- a/src/cache.py +++ b/src/cache.py @@ -2,7 +2,8 @@ import argparse import json import hashlib -from typing import Dict, Optional, Any, List +from typing import Dict, Optional, Any, List, Deque +from collections import deque import numpy as np from sentence_transformers import CrossEncoder @@ -14,8 +15,10 @@ # ----------------------------- # Global cache and constants # ----------------------------- -SEMANTIC_CACHE: Dict[str, List[Dict[str, Any]]] = {} +SEMANTIC_CACHE: Dict[str, Deque[Dict[str, Any]]] = {} SEMANTIC_CACHE_THRESHOLD = 0.85 +BI_ENCODER_THRESHOLD = 0.40 +CROSS_ENCODER_THRESHOLD = 0.75 SEMANTIC_CACHE_MAX_ENTRIES = 50 QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None @@ -75,7 +78,7 @@ def semantic_cache_lookup(config_key: str, query_embedding: np.ndarray, current_ # Step 1: Bi-Encoder filter (fast cosine similarity) candidates = [ entry for entry in entries - if entry.get("embedding") is not None and float(np.dot(entry["embedding"], query_embedding)) > 0.40 + if np.dot(entry["embedding"], query_embedding) > BI_ENCODER_THRESHOLD ] if not candidates: return None @@ -86,7 +89,7 @@ def semantic_cache_lookup(config_key: str, query_embedding: np.ndarray, current_ ce_scores = ce_model.predict(pairs, show_progress_bar=False) best_idx = int(np.argmax(ce_scores)) - if ce_scores[best_idx] > 0.75: + if ce_scores[best_idx] > CROSS_ENCODER_THRESHOLD: return candidates[best_idx]["payload"] return None @@ -99,7 +102,9 @@ def semantic_cache_store(config_key: str, normalized_question: str, question_emb if question_embedding is None: return - entries = SEMANTIC_CACHE.setdefault(config_key, []) + if config_key not in SEMANTIC_CACHE: + SEMANTIC_CACHE[config_key] = deque() + entries = SEMANTIC_CACHE[config_key] entries.append({ "question": normalized_question, "embedding": question_embedding.astype(np.float32), @@ -107,7 +112,7 @@ def semantic_cache_store(config_key: str, normalized_question: str, question_emb }) if len(entries) > SEMANTIC_CACHE_MAX_ENTRIES: - entries.pop(0) + entries.popleft() # ----------------------------- @@ -118,7 +123,7 @@ def get_question_embedder(retrievers: List[Any], embed_model: str) -> Optional[S Get or initialize a SentenceTransformer for encoding questions. Prefers the embedder from any FAISSRetriever in the retrievers list. """ - for retriever in retrievers or []: + for retriever in retrievers: if isinstance(retriever, FAISSRetriever): return retriever.embedder diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py index f707eb3b..124952fe 100644 --- a/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -187,28 +187,31 @@ def get_tokensmith_answer(question, config, golden_chunks=None): system_prompt_mode=config.get("system_prompt_mode"), ) - # Create RAGConfig from config.yaml and override with test config - cfg = RAGConfig.from_yaml("config/config.yaml") - if "top_k" in config: cfg.top_k = config.get("top_k", cfg.top_k) - if "pool_size" in config: cfg.num_candidates = config.get("pool_size", cfg.num_candidates) - if "embed_model" in config: cfg.embed_model = config.get("embed_model", cfg.embed_model) - if "ensemble_method" in config: cfg.ensemble_method = config.get("ensemble_method", cfg.ensemble_method) - if "retrieval_method" in config: cfg.ensemble_method = config.get("retrieval_method", cfg.ensemble_method) - if "ranker_weights" in config and config["ranker_weights"] is not None: cfg.ranker_weights = config.get("ranker_weights", cfg.ranker_weights) - if "rerank_mode" in config: cfg.rerank_mode = config.get("rerank_mode", cfg.rerank_mode) - if "rerank_top_k" in config: cfg.rerank_top_k = config.get("rerank_top_k", cfg.rerank_top_k) - if "system_prompt_mode" in config: cfg.system_prompt_mode = config.get("system_prompt_mode", cfg.system_prompt_mode) - if "max_gen_tokens" in config: cfg.max_gen_tokens = config.get("max_gen_tokens", cfg.max_gen_tokens) - if "model_path" in config: cfg.gen_model = config.get("model_path", cfg.gen_model) - if "disable_chunks" in config: cfg.disable_chunks = config.get("disable_chunks", cfg.disable_chunks) - if "use_golden_chunks" in config: cfg.use_golden_chunks = config.get("use_golden_chunks", cfg.use_golden_chunks) - if "output_mode" in config: cfg.output_mode = config.get("output_mode", cfg.output_mode) - if "metrics" in config: cfg.metrics = config.get("metrics", cfg.metrics) - if "use_hyde" in config: cfg.use_hyde = config.get("use_hyde", cfg.use_hyde) - if "hyde_max_tokens" in config: cfg.hyde_max_tokens = config.get("hyde_max_tokens", cfg.hyde_max_tokens) - if "use_indexed_chunks" in config: cfg.use_indexed_chunks = config.get("use_indexed_chunks", cfg.use_indexed_chunks) - if "extracted_index_path" in config: cfg.extracted_index_path = config.get("extracted_index_path", cfg.extracted_index_path) - if "page_to_chunk_map_path" in config: cfg.page_to_chunk_map_path = config.get("page_to_chunk_map_path", cfg.page_to_chunk_map_path) + # Create RAGConfig from our test config + cfg = RAGConfig( + chunk_config=RAGConfig.get_chunk_config(config), + top_k=config.get("top_k", 10), + pool_size=config.get("pool_size", 60), + embed_model=config.get("embed_model"), + ensemble_method=config.get("retrieval_method", "rrf"), + rrf_k=60, + ranker_weights=config.get("ranker_weights", {"faiss": 1, "bm25": 0}), + rerank_mode=config.get("rerank_mode", "none"), + rerank_top_k=config.get("rerank_top_k", 5), + seg_filter=config.get("seg_filter", None), + system_prompt_mode=config.get("system_prompt_mode", "baseline"), + max_gen_tokens=config.get("max_gen_tokens", 400), + model_path=config.get("model_path"), + disable_chunks=config.get("disable_chunks", False), + use_golden_chunks=config.get("use_golden_chunks", False), + output_mode=config.get("output_mode", "html"), + metrics=config.get("metrics", ["all"]), + use_hyde=config.get("use_hyde", False), + hyde_max_tokens=config.get("hyde_max_tokens", 300), + use_indexed_chunks=config.get("use_indexed_chunks", False), + extracted_index_path=config.get("extracted_index_path", "data/extracted_index.json"), + page_to_chunk_map_path=config.get("page_to_chunk_map_path", "index/sections/textbook_index_page_to_chunk_map.json"), + ) # Print status if golden_chunks and config["use_golden_chunks"]: @@ -223,8 +226,8 @@ def get_tokensmith_answer(question, config, golden_chunks=None): logger = get_logger() # Run the query through the main pipeline - artifacts_dir = cfg.get_artifacts_directory() - faiss_index, bm25_index, chunks, sources, meta = load_artifacts( + artifacts_dir = cfg.make_artifacts_directory() + faiss_index, bm25_index, chunks, sources = load_artifacts( artifacts_dir=artifacts_dir, index_prefix=config["index_prefix"] ) @@ -251,18 +254,16 @@ def get_tokensmith_answer(question, config, golden_chunks=None): "chunks": chunks, "sources": sources, "retrievers": retrievers, - "ranker": ranker, - "meta": meta + "ranker": ranker } - from rich.console import Console result = get_answer( question=question, cfg=cfg, args=args, logger=logger, artifacts=artifacts, - console=Console(quiet=True), + console=None, golden_chunks=golden_chunks, is_test_mode=True ) From 8614e7d68a6226fe8b823b04ba71cb45a451089b Mon Sep 17 00:00:00 2001 From: shahmeer99 Date: Thu, 12 Mar 2026 23:14:12 -0400 Subject: [PATCH 10/20] use RAGConfig methid get_config_state() to get payload for hashing key --- src/cache.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/cache.py b/src/cache.py index 771a2676..fb30e5de 100644 --- a/src/cache.py +++ b/src/cache.py @@ -44,20 +44,25 @@ def make_cache_config_key(cfg: RAGConfig, args: argparse.Namespace, golden_chunk """ Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. """ - payload = { - "gen_model": getattr(args, "model_path", None) or cfg.gen_model, - "embed_model": cfg.embed_model, - "top_k": cfg.top_k, - "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, - "ensemble_method": cfg.ensemble_method, - "ranker_weights": cfg.ranker_weights, - "use_hyde": cfg.use_hyde, - "use_indexed_chunks": cfg.use_indexed_chunks, - "disable_chunks": cfg.disable_chunks, - "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), - "index_prefix": getattr(args, "index_prefix", None), - } + try: + payload = RAGConfig.get_config_state() + except Exception as e: + payload = { + "gen_model": getattr(args, "model_path", None) or cfg.gen_model, + "embed_model": cfg.embed_model, + "top_k": cfg.top_k, + "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, + "ensemble_method": cfg.ensemble_method, + "ranker_weights": cfg.ranker_weights, + "use_hyde": cfg.use_hyde, + "use_indexed_chunks": cfg.use_indexed_chunks, + "disable_chunks": cfg.disable_chunks, + "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), + "index_prefix": getattr(args, "index_prefix", None), + } + + # !!! Unnencessary to include - to remove later !!! if golden_chunks and cfg.use_golden_chunks: signature = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() payload["golden_signature"] = signature From d783bc36b73bb866f9ea29c1d62c3a4f7c712ea6 Mon Sep 17 00:00:00 2001 From: shahmeer99 Date: Thu, 12 Mar 2026 23:57:51 -0400 Subject: [PATCH 11/20] add gpu support try for generator --- src/generator.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/generator.py b/src/generator.py index 5a4b1b74..5dbf050b 100644 --- a/src/generator.py +++ b/src/generator.py @@ -106,9 +106,20 @@ def format_prompt(chunks, query, max_chunk_chars=400, system_prompt_mode="tutor" def get_llama_model(model_path: str, n_ctx: int = 4096): if model_path not in _LLM_CACHE: - _LLM_CACHE[model_path] = Llama(model_path=model_path, + try: + _LLM_CACHE[model_path] = Llama(model_path=model_path, n_ctx=n_ctx, - verbose=False) + verbose=False, + # add gpu offloading with -1 + n_gpu_layers=-1, + ) + except Exception as e: + print(f"Error occurred while initializing Llama model on. GPU: {e}") + _LLM_CACHE[model_path] = Llama(model_path=model_path, + n_ctx=n_ctx, + verbose=False, + ) + return _LLM_CACHE[model_path] def stream_llama_cpp(prompt: str, model_path: str, max_tokens: int, temperature: float): From 1f275336a8464ed2c31180c6a4cfcacef0351fdd Mon Sep 17 00:00:00 2001 From: shahmeer99 Date: Fri, 13 Mar 2026 00:21:11 -0400 Subject: [PATCH 12/20] added comments to get_answer() in mainso all steps are explained --- src/main.py | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/src/main.py b/src/main.py index 21d4b01f..ca456f99 100644 --- a/src/main.py +++ b/src/main.py @@ -126,12 +126,13 @@ def get_answer( question_embedding: Optional[np.ndarray] = None semantic_hit: Optional[Dict[str, Any]] = None - # Check semantic cache + # STEP 1: Check Semantic Cache Hit (if enabled by ) if cfg.semantic_cache_enabled and config_cache_key in SEMANTIC_CACHE: question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) - # Return cached answer if found + # STEP 1.1: Return cached answer if found + # Note: Has fucntion exit to must do logging here if cfg.semantic_cache_enabled and semantic_hit: ans = semantic_hit.get("answer", "") @@ -142,30 +143,37 @@ def get_answer( render_final_answer(console, ans) return ans - # Step 1: Get chunks (golden, retrieved, or none) + # If no semantic hit, proceed with normal retrieval, ranking, and generation process + # Step 2: Retrieval chunks_info = None hyde_query = None if golden_chunks and cfg.use_golden_chunks: - # Use provided golden chunks + # Use provided golden chunks (testing mode only) ranked_chunks = golden_chunks elif cfg.disable_chunks: - # No chunks - baseline mode + # No chunks - baseline mode (only tests model knowledge) ranked_chunks = [] elif cfg.use_indexed_chunks: + # basic inverted index using keywords (keywords here are just non-stopword tokens in question) ranked_chunks, topk_idxs = use_indexed_chunks(question, chunks) else: + # Normal retrieval + ranking flow based on config retrieval_query = question - if cfg.use_hyde: + + # Step 2.1: [OPTIONAL] using HyDe + if cfg.use_hyde: # Hypothetical Document Embeddding Approach retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) + # Step 2.2: Get raw scores from each retriever pool_n = max(cfg.num_candidates, cfg.top_k + 10) raw_scores: Dict[str, Dict[int, float]] = {} for retriever in retrievers: raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks) # TODO: Fix retrieval logging. - - # Step 2: Ranking + + # Step 2.3: Rank retrieved chunks using ensemble method ordered, scores = ranker.rank(raw_scores=raw_scores) + # filter down ordered and chunks list to max len cfg.top_k topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered) ranked_chunks = [chunks[i] for i in topk_idxs] @@ -199,7 +207,7 @@ def get_answer( "index_rank": index_ranks.get(idx, 0), }) - # Step 3: Final re-ranking + # Step 3: Reranking with cross-encoder (if configured) ranked_chunks = rerank(question, ranked_chunks, mode=cfg.rerank_mode, top_n=cfg.rerank_top_k) if not ranked_chunks and not cfg.disable_chunks: @@ -210,8 +218,8 @@ def get_answer( model_path = cfg.gen_model system_prompt = args.system_prompt_mode or cfg.system_prompt_mode + # Step 4.1: Check for double prompting approach to improve answer quality (if enabled by config or CLI arg) use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt - if use_double: stream_iter = double_answer( question, @@ -220,6 +228,7 @@ def get_answer( max_tokens=cfg.max_gen_tokens, system_prompt_mode=system_prompt, ) + # If not double prompting, use normal answer method from generator.py else: stream_iter = answer( question, @@ -235,7 +244,7 @@ def get_answer( # Accumulate the full text while rendering incremental Markdown chunks ans = render_streaming_ans(console, stream_iter) - # Store in semantic cache + # Step 5: Store in semantic cache if enabled by config if cfg.semantic_cache_enabled: cache_payload = { "answer": ans, @@ -255,7 +264,7 @@ def get_answer( if is_test_mode: return ans, chunks_info, hyde_query - # Logging + # Step 5: Logging - log all relevant information to a json file in logs/ directory meta = artifacts.get("meta", []) page_nums = get_page_numbers(topk_idxs, meta) logger.save_chat_log( @@ -273,6 +282,7 @@ def get_answer( full_response=ans, top_k=len(topk_idxs) ) + return ans def render_streaming_ans(console, stream_iter): From f433a4106dcb1a56287b304079d4214819f3fa9f Mon Sep 17 00:00:00 2001 From: shahmeer99 Date: Fri, 13 Mar 2026 01:00:40 -0400 Subject: [PATCH 13/20] updated semantic cache testing to include false positive hit checking --- tests/cache_benchmark.yaml | 180 ++++++++++++++++++++++++ tests/test_cache_benchmark.py | 253 ++++++++++++++++++++++++---------- 2 files changed, 358 insertions(+), 75 deletions(-) diff --git a/tests/cache_benchmark.yaml b/tests/cache_benchmark.yaml index c369dbe4..58d2e363 100644 --- a/tests/cache_benchmark.yaml +++ b/tests/cache_benchmark.yaml @@ -6,6 +6,12 @@ - "Please tell me what a database management system is." - "What exactly is a database management system?" - "What is the definition of a database management system?" + trick_variations: + - "What is a file management system?" # semantic: file vs database + - "What is an operating system?" # semantic: OS vs DBMS + - "What is a spreadsheet application?" # semantic: adjacent tool, different thing + - "What is the management style of a data scientist?" # syntactic: plays on "management" and "data" + - "What does it mean to manage a system database in Windows?" # syntactic: "manage" + "database" in different order/context - id: q2 question: "What is the relational model?" @@ -15,6 +21,12 @@ - "Please tell me what the relational model is." - "What exactly is the relational model?" - "What is the definition of the relational model?" + trick_variations: + - "What is the hierarchical model?" # semantic: sibling data model + - "What is the network model in databases?" # semantic: sibling data model + - "What is the object-oriented data model?" # semantic: sibling data model + - "What is the relationship between a model and a tokenizer?" # syntactic: relational → relationship, model stays + - "What is a relational conflict in psychology?" # syntactic: relational used in a completely different domain - id: q3 question: "What is SQL?" @@ -24,6 +36,12 @@ - "Please tell me what SQL is." - "What exactly is SQL?" - "What is the definition of SQL?" + trick_variations: + - "What is NoSQL?" # semantic: directly contrasted technology + - "What are some alternatives to SQL?" # semantic: adjacent but asking for different things + - "What is the sequel to a popular movie franchise?" # syntactic: SQL is pronounced "sequel" + - "What is a skill queue list in project management?" # syntactic: S-Q-L letters reinterpreted + - "What is MySQL?" # semantic: MySQL is a DBMS, not the language itself - id: q4 question: "What is a primary key?" @@ -33,6 +51,12 @@ - "Please tell me what a primary key is." - "What exactly is a primary key?" - "What is the definition of a primary key?" + trick_variations: + - "What is a foreign key?" # semantic: sibling key concept + - "What is a candidate key?" # semantic: sibling key concept + - "What is a composite key?" # semantic: sibling key concept + - "What is a master key used for in physical security?" # syntactic: primary/master are synonyms, but in a different domain + - "What is the primary responsibility of a database lock?" # syntactic: "primary" + "key" split apart into different meaning - id: q5 question: "What is a foreign key?" @@ -42,6 +66,12 @@ - "Please tell me what a foreign key is." - "What exactly is a foreign key?" - "What is the definition of a foreign key?" + trick_variations: + - "What is a primary key?" # semantic: sibling key concept + - "What is a surrogate key?" # semantic: sibling key concept + - "What is a join in SQL?" # semantic: joins use FKs but are a different concept + - "What is a key used in foreign language translation?" # syntactic: "foreign" + "key" in a non-database context + - "What is an alien data type in a database column?" # syntactic: alien is a synonym for foreign, reframed - id: q6 question: "What is an Entity-Relationship model?" @@ -51,6 +81,12 @@ - "Please tell me what an Entity-Relationship model is." - "What exactly is an Entity-Relationship model?" - "What is the definition of an Entity-Relationship model?" + trick_variations: + - "What is the relational model?" # semantic: different model + - "What is a UML class diagram?" # semantic: different modeling notation + - "What is a data flow diagram?" # semantic: different diagram type + - "What is the relationship between an entity and its manager in HR?" # syntactic: entity + relationship in HR context + - "What is entity authentication in network security?" # syntactic: entity reused in a security context - id: q7 question: "What is normalization?" @@ -60,6 +96,12 @@ - "Please tell me what normalization is." - "What exactly is normalization?" - "What is the definition of normalization?" + trick_variations: + - "What is denormalization?" # semantic: direct opposite concept + - "What is normalization in machine learning?" # syntactic: same word, entirely different domain + - "What is normalization in signal processing?" # syntactic: same word, entirely different domain + - "What is data cleaning?" # semantic: adjacent preprocessing concept, different meaning + - "What is normal distribution in statistics?" # syntactic: "normal" shared but normalization in stats is not DB normalization - id: q8 question: "What is a transaction?" @@ -69,6 +111,12 @@ - "Please tell me what a transaction is." - "What exactly is a transaction?" - "What is the definition of a transaction?" + trick_variations: + - "What is a stored procedure?" # semantic: different DB construct + - "What is a SQL query?" # semantic: a query is not a transaction + - "What is a financial transaction in banking?" # syntactic: same word, entirely different domain + - "What is a transaction fee on a credit card?" # syntactic: transaction reused in payments context + - "What is a batch job in data processing?" # semantic: similar in feel but different concept - id: q9 question: "What are ACID properties?" @@ -78,6 +126,12 @@ - "Please tell me what ACID properties are." - "What exactly are ACID properties?" - "What is the definition of ACID properties?" + trick_variations: + - "What are BASE properties in NoSQL databases?" # semantic: directly contrasted acronym + - "What is the CAP theorem?" # semantic: related distributed systems theory + - "What are the chemical properties of an acid?" # syntactic: ACID is also a chemistry term + - "What are SOLID principles in software engineering?" # syntactic: SOLID is another well-known acronym, similar phrasing + - "What are database constraints?" # semantic: constraints enforce rules but are not ACID - id: q10 question: "What is concurrency control?" @@ -87,6 +141,12 @@ - "Please tell me what concurrency control is." - "What exactly is concurrency control?" - "What is the definition of concurrency control?" + trick_variations: + - "What is deadlock?" # semantic: a problem concurrency control solves, not the concept itself + - "What is two-phase locking?" # semantic: a mechanism used in concurrency control, not the concept + - "What is parallel query processing?" # semantic: related to concurrency but different concept + - "What is currency control in international finance?" # syntactic: con-currency → currency, control stays + - "What is access control in operating systems?" # syntactic: "control" shared but access control is not concurrency control - id: q11 question: "What involves indexing in databases?" @@ -96,6 +156,12 @@ - "What exactly involves indexing in databases?" - "I would like to know what involves indexing in databases." - "Could you clarify what involves indexing in databases?" + trick_variations: + - "What is query optimization?" # semantic: indexing helps but query opt is broader + - "What is database partitioning?" # semantic: adjacent storage concept + - "What is indexing a book in a library catalog?" # syntactic: indexing in publishing, not databases + - "What is the stock market index?" # syntactic: index in finance domain + - "What is a database schema design?" # semantic: schema design is not indexing - id: q12 question: "What is a B+ tree?" @@ -105,6 +171,12 @@ - "Please tell me what a B+ tree is." - "What exactly is a B+ tree?" - "What is the definition of a B+ tree?" + trick_variations: + - "What is a B-tree?" # semantic: sibling structure, different in important ways + - "What is a binary search tree?" # semantic: different tree structure + - "What is an AVL tree?" # semantic: different self-balancing tree + - "What are the health benefits of vitamin B complex?" # syntactic: B+ reads like a grade or vitamin "B plus" + - "What is a family tree in genealogy?" # syntactic: tree as a concept in a different domain - id: q13 question: "What is hashing?" @@ -114,6 +186,12 @@ - "Please tell me what hashing is." - "What exactly is hashing?" - "What is the definition of hashing?" + trick_variations: + - "What is encryption?" # semantic: often conflated with hashing but different + - "What is a digital signature?" # semantic: uses hashing but is a different concept + - "What is a hashtag on social media?" # syntactic: hash in a completely different context + - "What is hash brown preparation in cooking?" # syntactic: hash as a culinary term + - "What is checksum verification?" # semantic: similar purpose but different mechanism - id: q14 question: "What is query optimization?" @@ -123,6 +201,12 @@ - "Please tell me what query optimization is." - "What exactly is query optimization?" - "What is the definition of query optimization?" + trick_variations: + - "What is query parsing?" # semantic: a step before optimization, not optimization itself + - "What is query execution?" # semantic: what happens after optimization + - "What is optimization in machine learning?" # syntactic: optimization as a term in a different domain + - "What is search engine query optimization?" # syntactic: SEO query optimization means something totally different + - "What is database indexing?" # semantic: a tool that aids optimization, not optimization itself - id: q15 question: "What is Big Data?" @@ -132,6 +216,12 @@ - "Please tell me what Big Data is." - "What exactly is Big Data?" - "What is the definition of Big Data?" + trick_variations: + - "What is a data warehouse?" # semantic: related but a specific storage architecture + - "What is data mining?" # semantic: a technique applied to data, not a paradigm + - "What does it mean to have a big database?" # syntactic: "big" + "database" literally, not the coined term + - "What is a large file in cloud storage?" # syntactic: big data as in physically large files + - "What is cloud computing?" # semantic: often mentioned with Big Data but different concept - id: q16 question: "What is MapReduce?" @@ -141,6 +231,12 @@ - "Please tell me what MapReduce is." - "What exactly is MapReduce?" - "What is the definition of MapReduce?" + trick_variations: + - "What is Apache Spark?" # semantic: successor/alternative, different concept + - "What is a distributed file system?" # semantic: infrastructure MapReduce runs on + - "What is the map data structure in programming?" # syntactic: map() as a programming concept + - "What does it mean to reduce a fraction in mathematics?" # syntactic: reduce as a math operation + - "What is a Google Maps route reduction algorithm?" # syntactic: map + reduce individually repurposed - id: q17 question: "What is NoSQL?" @@ -150,6 +246,12 @@ - "Please tell me what NoSQL is." - "What exactly is NoSQL?" - "What is the definition of NoSQL?" + trick_variations: + - "What is SQL?" # semantic: the thing NoSQL is named after + - "What is NewSQL?" # semantic: similar-sounding but distinct category + - "What is a distributed database?" # semantic: NoSQL is often distributed but they are not the same + - "What is the SQL NOT operator in a query?" # syntactic: No as a negation of SQL + - "What does it mean to say no to a database request?" # syntactic: No + SQL interpreted literally - id: q18 question: "What is data mining?" @@ -159,6 +261,12 @@ - "Please tell me what data mining is." - "What exactly is data mining?" - "What is the definition of data mining?" + trick_variations: + - "What is a data warehouse?" # semantic: where data mining data often lives + - "What is machine learning?" # semantic: related field but different concept + - "What is cryptocurrency mining?" # syntactic: mining in the blockchain domain + - "What is coal mining and how does it work?" # syntactic: mining as a physical activity + - "What is ETL in data engineering?" # semantic: moving data, not extracting patterns - id: q19 question: "What is a data warehouse?" @@ -168,6 +276,12 @@ - "Please tell me what a data warehouse is." - "What exactly is a data warehouse?" - "What is the definition of a data warehouse?" + trick_variations: + - "What is a data lake?" # semantic: sibling storage concept + - "What is a data mart?" # semantic: a subset of a warehouse, different concept + - "What is an operational database?" # semantic: contrasted with analytical warehouse + - "What is a physical warehouse management system?" # syntactic: warehouse as a building/logistics term + - "What is data storage in a cold storage facility?" # syntactic: warehouse → cold storage, data stays - id: q20 question: "What is distributed database?" @@ -177,6 +291,12 @@ - "Please tell me what distributed database is." - "What exactly is distributed database?" - "What is the definition of distributed database?" + trick_variations: + - "What is a cloud database?" # semantic: often confused, but different concept + - "What is a parallel database?" # semantic: parallel is not distributed + - "What is a peer-to-peer network?" # semantic: similar topology but different concept + - "What is a distributed team in remote work management?" # syntactic: distributed used in an HR/org context + - "What is network latency in telecommunications?" # semantic: a challenge for distributed DBs, not the concept itself - id: q21 question: "What is database security?" @@ -186,6 +306,12 @@ - "Please tell me what database security is." - "What exactly is database security?" - "What is the definition of database security?" + trick_variations: + - "What is network security?" # semantic: different security domain + - "What is database backup and recovery?" # semantic: backup is operational, not security + - "What is national security policy?" # syntactic: security reused in political context + - "What is a security guard database used for?" # syntactic: "database" + "security" swapped in meaning + - "What is data privacy regulation?" # semantic: related but policy/legal, not technical security - id: q22 question: "What is RAID?" @@ -195,6 +321,12 @@ - "Please tell me what RAID is." - "What exactly is RAID?" - "What is the definition of RAID?" + trick_variations: + - "What is database replication?" # semantic: similar redundancy goal, different mechanism + - "What is database backup?" # semantic: related to durability but different concept + - "What is a police raid operation?" # syntactic: RAID as a law enforcement action + - "What is RAID insect spray and how does it work?" # syntactic: RAID as a brand name + - "What is a solid state drive?" # semantic: storage hardware, not a storage strategy - id: q23 question: "What is a view in SQL?" @@ -204,6 +336,12 @@ - "Please tell me what a view in SQL is." - "What exactly is a view in SQL?" - "What is the definition of a view in SQL?" + trick_variations: + - "What is a stored procedure in SQL?" # semantic: different SQL construct + - "What is a SQL index?" # semantic: different SQL construct + - "What is a scenic view from a mountain?" # syntactic: view as a visual/landscape concept + - "What is a view in the MVC software architecture pattern?" # syntactic: view as a UI layer in MVC + - "What is a database trigger?" # semantic: different DB automation construct - id: q24 question: "What is a trigger?" @@ -213,6 +351,12 @@ - "Please tell me what a trigger is." - "What exactly is a trigger?" - "What is the definition of a trigger?" + trick_variations: + - "What is a stored procedure?" # semantic: often confused with triggers + - "What is a database constraint?" # semantic: constraints enforce rules, triggers react to events + - "What is a trigger warning in psychology?" # syntactic: trigger in mental health context + - "What is the trigger mechanism on a firearm?" # syntactic: trigger as a physical gun component + - "What is a user-defined function in SQL?" # semantic: different programmable DB construct - id: q25 question: "What is two-phase locking?" @@ -222,6 +366,12 @@ - "Please tell me what two-phase locking is." - "What exactly is two-phase locking?" - "What is the definition of two-phase locking?" + trick_variations: + - "What is two-phase commit?" # syntactic: two-phase is shared, commit is not locking + - "What is optimistic locking?" # semantic: sibling locking strategy + - "What is timestamp-based concurrency control?" # semantic: alternative to locking + - "What is a two-phase electrical power system?" # syntactic: two-phase in an electrical engineering context + - "What is locking in a hairstyling technique?" # syntactic: locking as in dreadlocks - id: q26 question: "What is deadlock?" @@ -231,6 +381,12 @@ - "Please tell me what deadlock is." - "What exactly is deadlock?" - "What is the definition of deadlock?" + trick_variations: + - "What is livelock?" # semantic: counterpart concept to deadlock + - "What is starvation in database scheduling?" # semantic: related scheduling problem, different cause + - "What is a deadlock in labor union negotiations?" # syntactic: deadlock as a political/negotiation standoff + - "What is a dead lock on a door?" # syntactic: dead + lock as separate physical words + - "What is a race condition in concurrent programming?" # semantic: related concurrency bug, different mechanism - id: q27 question: "What is write-ahead logging?" @@ -240,6 +396,12 @@ - "Please tell me what write-ahead logging is." - "What exactly is write-ahead logging?" - "What is the definition of write-ahead logging?" + trick_variations: + - "What is database checkpointing?" # semantic: used alongside WAL but different mechanism + - "What is database backup?" # semantic: backup vs logging for recovery + - "What does it mean to write ahead in an academic outline?" # syntactic: write-ahead as a planning metaphor + - "What is activity logging in web server access logs?" # syntactic: logging in a completely different IT context + - "What is undo logging in database recovery?" # semantic: sibling logging strategy with different behavior - id: q28 question: "What is a stored procedure?" @@ -249,6 +411,12 @@ - "Please tell me what a stored procedure is." - "What exactly is a stored procedure?" - "What is the definition of a stored procedure?" + trick_variations: + - "What is a database trigger?" # semantic: often confused with stored procedures + - "What is a prepared statement?" # semantic: similar in purpose, different in nature + - "What is a storage procedure in a hospital supply room?" # syntactic: stored + procedure in a medical context + - "What is a procedure in a legal court case?" # syntactic: procedure as a legal term + - "What is a database cursor?" # semantic: different DB programmatic construct - id: q29 question: "What is database recovery?" @@ -258,6 +426,12 @@ - "Please tell me what database recovery is." - "What exactly is database recovery?" - "What is the definition of database recovery?" + trick_variations: + - "What is database backup?" # semantic: backup enables recovery but is not recovery itself + - "What is RAID?" # semantic: RAID provides redundancy, not recovery protocols + - "What is addiction recovery and how does it work?" # syntactic: recovery in a medical/social context + - "What is economic recovery after a recession?" # syntactic: recovery as a macroeconomic term + - "What is disaster recovery planning in IT?" # semantic: broader IT concept, not database-specific - id: q30 question: "What is XML?" @@ -267,3 +441,9 @@ - "Please tell me what XML is." - "What exactly is XML?" - "What is the definition of XML?" + trick_variations: + - "What is JSON?" # semantic: sibling data serialization format + - "What is HTML?" # semantic: XML sibling, different purpose + - "What is YAML?" # semantic: sibling serialization format + - "What is an X-ray and how does it work in medicine?" # syntactic: X as in XML's first letter, completely different domain + - "What does XML stand for in the military or government?" # syntactic: XML as a potential government acronym \ No newline at end of file diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py index 460b79e4..4f60bc4f 100644 --- a/tests/test_cache_benchmark.py +++ b/tests/test_cache_benchmark.py @@ -1,22 +1,8 @@ - -import pytest -import numpy as np -from unittest.mock import MagicMock, patch -from src.config import RAGConfig -from src.cache import ( - SEMANTIC_CACHE, - semantic_cache_store, - semantic_cache_lookup, - compute_question_embedding, - make_cache_config_key, - normalize_question -) - import pytest import numpy as np import yaml from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock from src.config import RAGConfig from src.cache import ( SEMANTIC_CACHE, @@ -24,23 +10,34 @@ semantic_cache_lookup, compute_question_embedding, make_cache_config_key, - normalize_question + normalize_question, ) + +# ----------------------------- +# Data loading +# ----------------------------- + def load_benchmark_data(): """Load benchmark questions from YAML.""" yaml_path = Path(__file__).parent / "cache_benchmark.yaml" with open(yaml_path, "r") as f: return yaml.safe_load(f) + BENCHMARK_DATA = load_benchmark_data() + +# ----------------------------- +# Fixtures +# ----------------------------- + @pytest.fixture def mock_config(): """Create a mock RAGConfig for testing.""" config = MagicMock(spec=RAGConfig) config.gen_model = "mock-model" - config.embed_model = "models/Qwen3-Embedding-4B-Q5_K_M.gguf" # Use local available model + config.embed_model = "models/Qwen3-Embedding-4B-Q5_K_M.gguf" config.top_k = 5 config.system_prompt_mode = "baseline" config.ensemble_method = "rrf" @@ -52,81 +49,187 @@ def mock_config(): config.semantic_cache_enabled = True return config + +# ----------------------------- +# Helpers +# ----------------------------- + +def print_separator(): + print("*" * 65) + + +def print_question_block(question_id, main_question, var_results, trick_results): + """ + Print a formatted block for a single question with tick/X results + for both variations and trick_variations. + + var_results: list of (question_str, hit: bool) + trick_results: list of (question_str, hit: bool) + """ + print_separator() + print(f" [{question_id}] {main_question}") + print() + + print(" Variations (hits are good ✅):") + for q, hit in var_results: + symbol = "✅" if hit else "❌" + print(f" {symbol} {q}") + + print() + print(" Trick Variations (hits are bad ⚠️):") + for q, hit in trick_results: + symbol = "⚠️ " if hit else "✅" + print(f" {symbol} {q}") + + var_hits = sum(1 for _, h in var_results if h) + trick_hits = sum(1 for _, h in trick_results if h) + print() + print(f" Recall : {var_hits}/{len(var_results)} variations matched") + print(f" Leakage: {trick_hits}/{len(trick_results)} trick variations falsely matched") + + +# ----------------------------- +# Test +# ----------------------------- + def test_cache_benchmark_comprehensive(mock_config): """ - Test 30 questions with ~10-15 variations each to verify semantic cache. + Benchmark the semantic cache against 30 questions, each with: + - 5 genuine paraphrase variations (hits expected) + - 5 trick variations (hits NOT expected) + + Scores: + Recall rate — fraction of genuine variations that got a cache hit. + Higher is better. Target >= 80%. + Leakage rate — fraction of trick variations that falsely got a cache hit. + Lower is better. Target <= 10%. """ - # 1. Clear cache + # Clear any state from previous runs SEMANTIC_CACHE.clear() - - # Setup + args = MagicMock() args.model_path = None args.system_prompt_mode = None args.index_prefix = "test_index" - + cache_key = make_cache_config_key(mock_config, args, None) embed_model_name = mock_config.embed_model - - total_hits = 0 + + total_var_hits = 0 total_variations = 0 - failures = [] + total_trick_hits = 0 + total_tricks = 0 + + recall_failures = [] + leakage_failures = [] - print(f"\n{'='*60}") - print(f" RUNNING COMPREHENSIVE CACHE BENCHMARK") - print(f" Model: {embed_model_name}") - print(f"{'='*60}") + print(f"\n{'*' * 65}") + print(f" SEMANTIC CACHE BENCHMARK") + print(f" Embedding model : {embed_model_name}") + print(f" Questions : {len(BENCHMARK_DATA)}") + print(f" Variations each : 5 genuine + 5 trick") + print(f"{'*' * 65}") for entry in BENCHMARK_DATA: - question_id = entry["id"] - main_question = entry["question"] - variations = entry["variations"] - - print(f"\n[{question_id}] Seeding: '{main_question}'") - - # 2. Seed the main question + question_id = entry["id"] + main_question = entry["question"] + variations = entry.get("variations", []) + trick_vars = entry.get("trick_variations", []) + + # --- Seed the cache with the canonical question --- normalized_main = normalize_question(main_question) - embedding_main = compute_question_embedding(normalized_main, [], embed_model_name) - assert embedding_main is not None, f"Failed to compute embedding for {question_id}" - + embedding_main = compute_question_embedding(normalized_main, [], embed_model_name) + assert embedding_main is not None, ( + f"Failed to compute embedding for {question_id}: '{main_question}'" + ) + payload = { - "answer": f"Cached answer for {question_id}", - "chunks_info": [], - "hyde_query": None, - "chunk_indices": [] + "answer": f"Cached answer for {question_id}", + "chunks_info": [], + "hyde_query": None, + "chunk_indices": [], } - semantic_cache_store(cache_key, normalized_main, embedding_main, payload) - - # 3. Test variations - entry_hits = 0 - + + # --- Test genuine variations --- + var_results = [] for var_q in variations: normalized_var = normalize_question(var_q) - embedding_var = compute_question_embedding(normalized_var, [], embed_model_name) - - result = semantic_cache_lookup(cache_key, embedding_var, normalized_var) - - if result: - print(f" ✅ Hit: '{var_q}'") - entry_hits += 1 - else: - failures.append(f"[{question_id}] Missed: '{var_q}'") - print(f" ❌ Missed: '{var_q}'") - - total_hits += entry_hits - total_variations += len(variations) - print(f" ✅ Hits: {entry_hits}/{len(variations)}") - - # Summary - hit_rate = total_hits / total_variations if total_variations > 0 else 0 - print(f"\n{'='*60}") - print(f" SUMMARY") - print(f" Total Variations: {total_variations}") - print(f" Total Hits: {total_hits}") - print(f" Hit Rate: {hit_rate:.2%}") - print(f"{'='*60}") - - # Assert acceptable hit rate (e.g., >80% for semantic similarity) - # Given we are using a good embedding model, it should be high. - assert hit_rate >= 0.80, f"Cache hit rate {hit_rate:.2%} is below 80%. Failures: {failures[:10]}..." + embedding_var = compute_question_embedding(normalized_var, [], embed_model_name) + hit = semantic_cache_lookup(cache_key, embedding_var, normalized_var) is not None + var_results.append((var_q, hit)) + if not hit: + recall_failures.append(f"[{question_id}] Missed variation : '{var_q}'") + + # --- Test trick variations --- + trick_results = [] + for trick_q in trick_vars: + normalized_trick = normalize_question(trick_q) + embedding_trick = compute_question_embedding(normalized_trick, [], embed_model_name) + hit = semantic_cache_lookup(cache_key, embedding_trick, normalized_trick) is not None + trick_results.append((trick_q, hit)) + if hit: + leakage_failures.append(f"[{question_id}] False hit on trick: '{trick_q}'") + + # --- Accumulate totals --- + total_var_hits += sum(1 for _, h in var_results if h) + total_variations += len(var_results) + total_trick_hits += sum(1 for _, h in trick_results if h) + total_tricks += len(trick_results) + + # --- Print per-question block --- + print_question_block(question_id, main_question, var_results, trick_results) + + # --- Final summary --- + recall_rate = total_var_hits / total_variations if total_variations > 0 else 0.0 + leakage_rate = total_trick_hits / total_tricks if total_tricks > 0 else 0.0 + + print_separator() + print() + print(f" {'FINAL BENCHMARK RESULTS':^61}") + print() + print(f" {'Metric':<35} {'Score':>10} {'Count'}") + print(f" {'-'*60}") + print(f" {'Recall Rate (higher is better)':<35} {recall_rate:>9.1%} {total_var_hits}/{total_variations}") + print(f" {'Leakage Rate (lower is better)':<35} {leakage_rate:>9.1%} {total_trick_hits}/{total_tricks}") + print() + print(" What these scores mean:") + print() + print(" Recall Rate — measures how often the cache correctly") + print(" recognises a genuine paraphrase of a cached question.") + print(" A high recall rate means users asking the same thing") + print(" in different words will get a fast cached response.") + print(" Target: >= 80%") + print() + print(" Leakage Rate — measures how often the cache is fooled") + print(" into returning an answer for a semantically or") + print(" syntactically similar but DIFFERENT question. A false") + print(" hit here means a user gets the wrong cached answer.") + print(" Lower is strictly better. Target: <= 10%") + + if recall_failures: + print() + print(f" Recall misses ({len(recall_failures)} total, showing first 10):") + for msg in recall_failures[:10]: + print(f" ❌ {msg}") + + if leakage_failures: + print() + print(f" Leakage hits ({len(leakage_failures)} total, showing first 10):") + for msg in leakage_failures[:10]: + print(f" ⚠️ {msg}") + + print() + print_separator() + + # --- Assertions --- + assert recall_rate >= 0.80, ( + f"Recall rate {recall_rate:.1%} is below the 80% target. " + f"The cache is missing too many genuine paraphrases.\n" + f"First 10 misses: {recall_failures[:10]}" + ) + assert leakage_rate <= 0.10, ( + f"Leakage rate {leakage_rate:.1%} exceeds the 10% target. " + f"The cache is returning false hits for trick questions.\n" + f"First 10 leaks: {leakage_failures[:10]}" + ) \ No newline at end of file From a477f417125bd6ab7874b60df9fb6f89deb028e5 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Fri, 13 Mar 2026 01:23:01 -0400 Subject: [PATCH 14/20] fix review comments --- src/cache.py | 2 -- src/main.py | 15 +++++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/cache.py b/src/cache.py index 771a2676..adc9303b 100644 --- a/src/cache.py +++ b/src/cache.py @@ -1,4 +1,3 @@ -from yaml import Node import argparse import json import hashlib @@ -16,7 +15,6 @@ # Global cache and constants # ----------------------------- SEMANTIC_CACHE: Dict[str, Deque[Dict[str, Any]]] = {} -SEMANTIC_CACHE_THRESHOLD = 0.85 BI_ENCODER_THRESHOLD = 0.40 CROSS_ENCODER_THRESHOLD = 0.75 SEMANTIC_CACHE_MAX_ENTRIES = 50 diff --git a/src/main.py b/src/main.py index 9b0d4b74..1fa961c6 100644 --- a/src/main.py +++ b/src/main.py @@ -116,18 +116,21 @@ def get_answer( topk_idxs: List[int] = [] scores = [] - normalized_question = normalize_question(question) - config_cache_key = make_cache_config_key(cfg, args, golden_chunks) + normalized_question = None + config_cache_key = None question_embedding: Optional[np.ndarray] = None semantic_hit: Optional[Dict[str, Any]] = None # Check semantic cache - if cfg.semantic_cache_enabled and config_cache_key in SEMANTIC_CACHE: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) - semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) + if cfg.semantic_cache_enabled: + normalized_question = normalize_question(question) + config_cache_key = make_cache_config_key(cfg, args, golden_chunks) + if config_cache_key in SEMANTIC_CACHE: + question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) + semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) # Return cached answer if found - if cfg.semantic_cache_enabled and semantic_hit: + if cfg.semantic_cache_enabled and semantic_hit is not None: ans = semantic_hit.get("answer", "") From af7d65e1d1c7169f7863cca43e06bc5bd5949d95 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 2 Apr 2026 17:36:28 -0400 Subject: [PATCH 15/20] fix the cache threshold --- src/cache.py | 4 ++-- tests/test_cache_benchmark.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/cache.py b/src/cache.py index d3539c4c..f0121466 100644 --- a/src/cache.py +++ b/src/cache.py @@ -15,8 +15,8 @@ # Global cache and constants # ----------------------------- SEMANTIC_CACHE: Dict[str, Deque[Dict[str, Any]]] = {} -BI_ENCODER_THRESHOLD = 0.40 -CROSS_ENCODER_THRESHOLD = 0.75 +BI_ENCODER_THRESHOLD = 0.95 +CROSS_ENCODER_THRESHOLD = 0.99 SEMANTIC_CACHE_MAX_ENTRIES = 50 QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py index d359521b..429e7d4c 100644 --- a/tests/test_cache_benchmark.py +++ b/tests/test_cache_benchmark.py @@ -232,7 +232,7 @@ def test_cache_benchmark_comprehensive(mock_config): f"The cache is missing too many genuine paraphrases.\n" f"First 10 misses: {accuracy_failures[:10]}" ) - assert false_positive_rate <= 0.10, ( + assert false_positive_rate <= 0.05, ( f"False Positives rate {false_positive_rate:.1%} exceeds the 10% target. " f"The cache is returning false hits for adversarial queries.\n" f"First 10 leaks: {false_positive_failures[:10]}" From 0e27032edac6f33536eebc3ae89a59897f02a369 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 2 Apr 2026 18:12:20 -0400 Subject: [PATCH 16/20] fixed interface and add no-op cache --- src/cache.py | 311 +++++++++++++++++++--------------- src/main.py | 53 +++--- tests/test_cache_benchmark.py | 44 +++-- 3 files changed, 217 insertions(+), 191 deletions(-) diff --git a/src/cache.py b/src/cache.py index f0121466..d9f2c986 100644 --- a/src/cache.py +++ b/src/cache.py @@ -3,6 +3,7 @@ import hashlib from typing import Dict, Optional, Any, List, Deque from collections import deque +from abc import ABC, abstractmethod import numpy as np from sentence_transformers import CrossEncoder @@ -11,147 +12,185 @@ from src.config import RAGConfig from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts, filter_retrieved_chunks -# ----------------------------- -# Global cache and constants -# ----------------------------- -SEMANTIC_CACHE: Dict[str, Deque[Dict[str, Any]]] = {} -BI_ENCODER_THRESHOLD = 0.95 -CROSS_ENCODER_THRESHOLD = 0.99 -SEMANTIC_CACHE_MAX_ENTRIES = 50 -QUESTION_EMBEDDERS: Dict[str, SentenceTransformer] = {} -CROSS_ENCODER_MODEL: Optional[CrossEncoder] = None - - -# ----------------------------- -# Utilities -# ----------------------------- -def get_cross_encoder() -> CrossEncoder: - """Return a global cross-encoder model instance, initializing if needed.""" - global CROSS_ENCODER_MODEL - if CROSS_ENCODER_MODEL is None: - CROSS_ENCODER_MODEL = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') - return CROSS_ENCODER_MODEL - - -def normalize_question(q: str) -> str: - """Normalize a question string: lowercase, strip, and collapse spaces.""" - return " ".join((q or "").strip().lower().split()) - - -def make_cache_config_key(cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: - """ - Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. - """ - - try: - payload = RAGConfig.get_config_state() - except Exception as e: - payload = { - "gen_model": getattr(args, "model_path", None) or cfg.gen_model, - "embed_model": cfg.embed_model, - "top_k": cfg.top_k, - "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, - "ensemble_method": cfg.ensemble_method, - "ranker_weights": cfg.ranker_weights, - "use_hyde": cfg.use_hyde, - "use_indexed_chunks": cfg.use_indexed_chunks, - "disable_chunks": cfg.disable_chunks, - "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), - "index_prefix": getattr(args, "index_prefix", None), - } - - # !!! Unnencessary to include - to remove later !!! - if golden_chunks and cfg.use_golden_chunks: - signature = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() - payload["golden_signature"] = signature - - return json.dumps(payload, sort_keys=True) - -# ----------------------------- -# Semantic cache operations -# ----------------------------- -def semantic_cache_lookup(config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: - """ - Retrieve a cached answer if semantically similar to the current question. - """ - entries = SEMANTIC_CACHE.get(config_key, []) - if not entries or query_embedding is None: - return None - # Step 1: Bi-Encoder filter (fast cosine similarity) - candidates = [ - entry for entry in entries - if np.dot(entry["embedding"], query_embedding) > BI_ENCODER_THRESHOLD - ] - if not candidates: +class Cache(ABC): + @abstractmethod + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + pass + + @abstractmethod + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + @abstractmethod + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + pass + + @abstractmethod + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: + pass + + @abstractmethod + def normalize_question(self, q: str) -> str: + pass + + +class SemanticCache(Cache): + def __init__(self, bi_encoder_threshold: float = 0.90, cross_encoder_threshold: float = 0.99, max_entries: int = 50): + self.cache: Dict[str, Deque[Dict[str, Any]]] = {} + self.bi_encoder_threshold = bi_encoder_threshold + self.cross_encoder_threshold = cross_encoder_threshold + self.max_entries = max_entries + self.question_embedders: Dict[str, SentenceTransformer] = {} + self.cross_encoder_model: Optional[CrossEncoder] = None + + def _get_cross_encoder(self) -> CrossEncoder: + """Return a global cross-encoder model instance, initializing if needed.""" + if self.cross_encoder_model is None: + self.cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') + return self.cross_encoder_model + + def normalize_question(self, q: str) -> str: + """Normalize a question string: lowercase, strip, and collapse spaces.""" + return " ".join((q or "").strip().lower().split()) + + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + """ + Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. + """ + try: + payload = RAGConfig.get_config_state() + except Exception: + payload = { + "gen_model": getattr(args, "model_path", None) or cfg.gen_model, + "embed_model": cfg.embed_model, + "top_k": cfg.top_k, + "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, + "ensemble_method": cfg.ensemble_method, + "ranker_weights": cfg.ranker_weights, + "use_hyde": cfg.use_hyde, + "use_indexed_chunks": cfg.use_indexed_chunks, + "disable_chunks": cfg.disable_chunks, + "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), + "index_prefix": getattr(args, "index_prefix", None), + } + + if golden_chunks and cfg.use_golden_chunks: + signature = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() + payload["golden_signature"] = signature + + return json.dumps(payload, sort_keys=True) + + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a cached answer if semantically similar to the current question. + """ + entries = self.cache.get(config_key, []) + if not entries or query_embedding is None: + return None + + # Step 1: Bi-Encoder filter (fast cosine similarity) + candidates = [ + entry for entry in entries + if np.dot(entry["embedding"], query_embedding) > self.bi_encoder_threshold + ] + if not candidates: + return None + + # Step 2: Cross-Encoder verification + ce_model = self._get_cross_encoder() + pairs = [[current_question, c["question"]] for c in candidates] + ce_scores = ce_model.predict(pairs, show_progress_bar=False) + best_idx = int(np.argmax(ce_scores)) + + if ce_scores[best_idx] > self.cross_encoder_threshold: + return candidates[best_idx]["payload"] return None - # Step 2: Cross-Encoder verification - ce_model = get_cross_encoder() - pairs = [[current_question, c["question"]] for c in candidates] - ce_scores = ce_model.predict(pairs, show_progress_bar=False) - best_idx = int(np.argmax(ce_scores)) - - if ce_scores[best_idx] > CROSS_ENCODER_THRESHOLD: - return candidates[best_idx]["payload"] - return None - - -def semantic_cache_store(config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: - """ - Store a question, its embedding, and the generated answer in the semantic cache. - Evict oldest entries if cache exceeds SEMANTIC_CACHE_MAX_ENTRIES. - """ - if question_embedding is None: - return - - if config_key not in SEMANTIC_CACHE: - SEMANTIC_CACHE[config_key] = deque() - entries = SEMANTIC_CACHE[config_key] - entries.append({ - "question": normalized_question, - "embedding": question_embedding.astype(np.float32), - "payload": payload, - }) - - if len(entries) > SEMANTIC_CACHE_MAX_ENTRIES: - entries.popleft() - - -# ----------------------------- -# Question embedding -# ----------------------------- -def get_question_embedder(retrievers: List[Any], embed_model: str) -> Optional[SentenceTransformer]: - """ - Get or initialize a SentenceTransformer for encoding questions. - Prefers the embedder from any FAISSRetriever in the retrievers list. - """ - for retriever in retrievers: - if isinstance(retriever, FAISSRetriever): - return retriever.embedder - - model_path = embed_model - if not model_path: + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + """ + Store a question, its embedding, and the generated answer in the semantic cache. + Evict oldest entries if cache exceeds self.max_entries. + """ + if question_embedding is None: + return + + if config_key not in self.cache: + self.cache[config_key] = deque() + entries = self.cache[config_key] + entries.append({ + "question": normalized_question, + "embedding": question_embedding.astype(np.float32), + "payload": payload, + }) + + if len(entries) > self.max_entries: + entries.popleft() + + def clear(self) -> None: + self.cache.clear() + + def _get_question_embedder(self, retrievers: List[Any], embed_model: str) -> Optional[SentenceTransformer]: + """ + Get or initialize a SentenceTransformer for encoding questions. + Prefers the embedder from any FAISSRetriever in the retrievers list. + """ + for retriever in retrievers: + if isinstance(retriever, FAISSRetriever): + return retriever.embedder + + model_path = embed_model + if not model_path: + return None + + embedder = self.question_embedders.get(model_path) + if embedder is None: + embedder = SentenceTransformer(model_path) + self.question_embedders[model_path] = embedder + + return embedder + + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: + """ + Compute a normalized embedding vector for a question using the configured embedder. + """ + embedder = self._get_question_embedder(retrievers, embed_model) + if not embedder: + return None + + vec = embedder.encode([question], batch_size=1, normalize=True, show_progress_bar=False) + if vec.size == 0: + return None + + return vec[0] + + +class NoOpCache(Cache): + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: return None - embedder = QUESTION_EMBEDDERS.get(model_path) - if embedder is None: - embedder = SentenceTransformer(model_path) - QUESTION_EMBEDDERS[model_path] = embedder - - return embedder - - -def compute_question_embedding(question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: - """ - Compute a normalized embedding vector for a question using the configured embedder. - """ - embedder = get_question_embedder(retrievers, embed_model) - if not embedder: + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + pass + + def clear(self) -> None: + pass + + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + return "" + + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: return None + + def normalize_question(self, q: str) -> str: + return "" - vec = embedder.encode([question], batch_size=1, normalize=True, show_progress_bar=False) - if vec.size == 0: - return None - return vec[0] +def get_cache(cfg: RAGConfig) -> Cache: + """Return a configured cache layer, either SemanticCache or NoOpCache depending on config.""" + if getattr(cfg, 'semantic_cache_enabled', False): + return SemanticCache() + return NoOpCache() diff --git a/src/main.py b/src/main.py index d947074a..ccdb81c7 100644 --- a/src/main.py +++ b/src/main.py @@ -28,7 +28,7 @@ load_artifacts ) from src.ranking.reranker import rerank -from src.cache import SEMANTIC_CACHE, semantic_cache_store, compute_question_embedding, normalize_question, make_cache_config_key, semantic_cache_lookup +from src.cache import get_cache ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question." @@ -122,21 +122,15 @@ def get_answer( topk_idxs: List[int] = [] scores = [] - normalized_question = None - config_cache_key = None - question_embedding: Optional[np.ndarray] = None - semantic_hit: Optional[Dict[str, Any]] = None - - # Check semantic cache - if cfg.semantic_cache_enabled: - normalized_question = normalize_question(question) - config_cache_key = make_cache_config_key(cfg, args, golden_chunks) - if config_cache_key in SEMANTIC_CACHE: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) - semantic_hit = semantic_cache_lookup(config_cache_key, question_embedding, normalized_question) + cache = get_cache(cfg) + normalized_question = cache.normalize_question(question) + config_cache_key = cache.make_config_key(cfg, args, golden_chunks) + question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model) + + semantic_hit = cache.lookup(config_cache_key, question_embedding, normalized_question) # Return cached answer if found - if cfg.semantic_cache_enabled and semantic_hit is not None: + if semantic_hit is not None: ans = semantic_hit.get("answer", "") @@ -277,22 +271,21 @@ def get_answer( additional_log_info=additional_log_info ) - # Step 5: Store in semantic cache if enabled by config - if cfg.semantic_cache_enabled: - cache_payload = { - "answer": ans, - "chunks_info": chunks_info, - "hyde_query": hyde_query, - "chunk_indices": topk_idxs, - } - if question_embedding is None: - question_embedding = compute_question_embedding(normalized_question, retrievers, cfg.embed_model) - semantic_cache_store( - config_cache_key, - normalized_question, - question_embedding, - cache_payload - ) + # Step 5: Store in semantic cache + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model) + cache.store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload + ) if is_test_mode: return ans, chunks_info, hyde_query diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py index 429e7d4c..e259b8ec 100644 --- a/tests/test_cache_benchmark.py +++ b/tests/test_cache_benchmark.py @@ -4,14 +4,7 @@ from pathlib import Path from unittest.mock import MagicMock from src.config import RAGConfig -from src.cache import ( - SEMANTIC_CACHE, - semantic_cache_store, - semantic_cache_lookup, - compute_question_embedding, - make_cache_config_key, - normalize_question, -) +from src.cache import get_cache # ----------------------------- @@ -100,19 +93,20 @@ def test_cache_benchmark_comprehensive(mock_config): Scores: Accuracy rate — fraction of genuine variations that got a cache hit. - Higher is better. Target >= 80%. + Higher is better. Target >= 60%. False Positives rate — fraction of adversarial queries that falsely got a cache hit. - Lower is better. Target <= 10%. + Lower is better. Target <= 0%. """ - # Clear any state from previous runs - SEMANTIC_CACHE.clear() + + cache = get_cache(mock_config) + cache.clear() args = MagicMock() args.model_path = None args.system_prompt_mode = None args.index_prefix = "test_index" - cache_key = make_cache_config_key(mock_config, args, None) + cache_key = cache.make_config_key(mock_config, args, None) embed_model_name = mock_config.embed_model total_var_hits = 0 @@ -137,8 +131,8 @@ def test_cache_benchmark_comprehensive(mock_config): adversarial_vars = entry.get("adversarial_queries", []) # --- Seed the cache with the canonical question --- - normalized_main = normalize_question(main_question) - embedding_main = compute_question_embedding(normalized_main, [], embed_model_name) + normalized_main = cache.normalize_question(main_question) + embedding_main = cache.compute_embedding(normalized_main, [], embed_model_name) assert embedding_main is not None, ( f"Failed to compute embedding for {question_id}: '{main_question}'" ) @@ -149,14 +143,14 @@ def test_cache_benchmark_comprehensive(mock_config): "hyde_query": None, "chunk_indices": [], } - semantic_cache_store(cache_key, normalized_main, embedding_main, payload) + cache.store(cache_key, normalized_main, embedding_main, payload) # --- Test genuine variations --- var_results = [] for var_q in variations: - normalized_var = normalize_question(var_q) - embedding_var = compute_question_embedding(normalized_var, [], embed_model_name) - hit = semantic_cache_lookup(cache_key, embedding_var, normalized_var) is not None + normalized_var = cache.normalize_question(var_q) + embedding_var = cache.compute_embedding(normalized_var, [], embed_model_name) + hit = cache.lookup(cache_key, embedding_var, normalized_var) is not None var_results.append((var_q, hit)) if not hit: accuracy_failures.append(f"[{question_id}] Missed variation : '{var_q}'") @@ -164,10 +158,10 @@ def test_cache_benchmark_comprehensive(mock_config): # --- Test adversarial queries --- adversarial_results = [] for adversarial_q in adversarial_vars: - normalized_adversarial = normalize_question(adversarial_q) - embedding_adversarial = compute_question_embedding(normalized_adversarial, [], embed_model_name) + normalized_adversarial = cache.normalize_question(adversarial_q) + embedding_adversarial = cache.compute_embedding(normalized_adversarial, [], embed_model_name) - payload_hit = semantic_cache_lookup(cache_key, embedding_adversarial, normalized_adversarial) + payload_hit = cache.lookup(cache_key, embedding_adversarial, normalized_adversarial) # A false hit is when the cache incorrectly returns the CURRENT question's answer # for a query that is semantically different. hit = payload_hit is not None and payload_hit.get("answer") == f"Cached answer for {question_id}" @@ -203,13 +197,13 @@ def test_cache_benchmark_comprehensive(mock_config): print(" recognises a genuine paraphrase of a cached question.") print(" A high accuracy rate means users asking the same thing") print(" in different words will get a fast cached response.") - print(" Target: >= 80%") + print(" Target: >= 60%") print() print(" False Positive Rate — measures how often the cache is fooled") print(" into returning an answer for a semantically or") print(" syntactically similar but DIFFERENT question. A false") print(" hit here means a user gets the wrong cached answer.") - print(" Lower is strictly better. Target: <= 10%") + print(" Lower is strictly better. Target: <= 0%") if accuracy_failures: print() @@ -233,7 +227,7 @@ def test_cache_benchmark_comprehensive(mock_config): f"First 10 misses: {accuracy_failures[:10]}" ) assert false_positive_rate <= 0.05, ( - f"False Positives rate {false_positive_rate:.1%} exceeds the 10% target. " + f"False Positives rate {false_positive_rate:.1%} exceeds the 5% target. " f"The cache is returning false hits for adversarial queries.\n" f"First 10 leaks: {false_positive_failures[:10]}" ) \ No newline at end of file From 9a3f982076e398d5ff9da2d98f4247815f333c29 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 2 Apr 2026 19:16:33 -0400 Subject: [PATCH 17/20] use the same cache for every call --- src/cache.py | 7 ++++++- src/main.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/cache.py b/src/cache.py index d9f2c986..707ffe5e 100644 --- a/src/cache.py +++ b/src/cache.py @@ -189,8 +189,13 @@ def normalize_question(self, q: str) -> str: return "" +_GLOBAL_SEMANTIC_CACHE: Optional[SemanticCache] = None + def get_cache(cfg: RAGConfig) -> Cache: """Return a configured cache layer, either SemanticCache or NoOpCache depending on config.""" + global _GLOBAL_SEMANTIC_CACHE if getattr(cfg, 'semantic_cache_enabled', False): - return SemanticCache() + if _GLOBAL_SEMANTIC_CACHE is None: + _GLOBAL_SEMANTIC_CACHE = SemanticCache() + return _GLOBAL_SEMANTIC_CACHE return NoOpCache() diff --git a/src/main.py b/src/main.py index ccdb81c7..364f21ee 100644 --- a/src/main.py +++ b/src/main.py @@ -136,7 +136,7 @@ def get_answer( if is_test_mode: return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") - + console.print("Using cached answer") render_final_answer(console, ans) return ans From ce80f6e140258b64f333518a07c1fa811463f74b Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 2 Apr 2026 19:23:46 -0400 Subject: [PATCH 18/20] clean up PR --- config/config.yaml | 2 +- src/generator.py | 6 ++---- src/main.py | 17 +++++------------ 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index bf35f9cb..00b391c9 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,4 +17,4 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 -semantic_cache_enabled: false +semantic_cache_enabled: true diff --git a/src/generator.py b/src/generator.py index 1c852e9b..8e8b3181 100644 --- a/src/generator.py +++ b/src/generator.py @@ -118,12 +118,10 @@ def get_llama_model(model_path: str, n_ctx: int = 4096): verbose=False, n_gpu_layers=-1) except Exception as e: - print(f"Error occurred while initializing Llama model on. GPU: {e}") + print(f"Error loading LLaMA model from {model_path} on GPU: {e}") _LLM_CACHE[model_path] = Llama(model_path=model_path, n_ctx=n_ctx, - verbose=False, - ) - + verbose=False) return _LLM_CACHE[model_path] def stream_llama_cpp(prompt: str, model_path: str, max_tokens: int, temperature: float): diff --git a/src/main.py b/src/main.py index 364f21ee..2a0c67aa 100644 --- a/src/main.py +++ b/src/main.py @@ -5,7 +5,6 @@ import json import pathlib import sys -import numpy as np from typing import Dict, Optional, List, Tuple, Union, Any from rich.live import Live @@ -140,27 +139,23 @@ def get_answer( render_final_answer(console, ans) return ans - # If no semantic hit, proceed with normal retrieval, ranking, and generation process - # Step 2: Retrieval + # Step 1: Get chunks (golden, retrieved, or none) chunks_info = None hyde_query = None if golden_chunks and cfg.use_golden_chunks: - # Use provided golden chunks (testing mode only) + # Use provided golden chunks ranked_chunks = golden_chunks elif cfg.disable_chunks: - # No chunks - baseline mode (only tests model knowledge) + # No chunks - baseline mode ranked_chunks = [] elif cfg.use_indexed_chunks: - # basic inverted index using keywords (keywords here are just non-stopword tokens in question) ranked_chunks, topk_idxs = use_indexed_chunks(question, chunks) else: - # Normal retrieval + ranking flow based on config retrieval_query = question # print(f"Retrieval query: {retrieval_query}") if cfg.use_hyde: retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) - - # Step 2.2: Get raw scores from each retriever + pool_n = max(cfg.num_candidates, cfg.top_k + 10) raw_scores: Dict[str, Dict[int, float]] = {} for retriever in retrievers: @@ -211,7 +206,7 @@ def get_answer( "index_rank": index_ranks.get(idx, 0), }) - # Step 3: Reranking with cross-encoder (if configured) + # Step 3: Final re-ranking ranked_chunks = rerank(question, ranked_chunks, mode=cfg.rerank_mode, top_n=cfg.rerank_top_k) # print("Reranked Chunks", type(ranked_chunks), len(ranked_chunks), type(ranked_chunks[0]) if ranked_chunks else "No chunks") # print("Example reranked chunk content:", ranked_chunks[0] if ranked_chunks else "No chunks after reranking") @@ -225,7 +220,6 @@ def get_answer( model_path = cfg.gen_model system_prompt = args.system_prompt_mode or cfg.system_prompt_mode - # Step 4.1: Check for double prompting approach to improve answer quality (if enabled by config or CLI arg) use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt if use_double: stream_iter = double_answer( @@ -235,7 +229,6 @@ def get_answer( max_tokens=cfg.max_gen_tokens, system_prompt_mode=system_prompt, ) - # If not double prompting, use normal answer method from generator.py else: stream_iter = answer( question, From bc765ba398d09e0e2c351595e574f2cd798b7418 Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 9 Apr 2026 08:33:13 -0400 Subject: [PATCH 19/20] fix review comments --- config/config.yaml | 2 ++ src/cache.py | 15 +++++++++------ src/config.py | 2 ++ tests/test_cache_benchmark.py | 2 ++ 4 files changed, 15 insertions(+), 6 deletions(-) diff --git a/config/config.yaml b/config/config.yaml index 00b391c9..734033d7 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -18,3 +18,5 @@ use_double_prompt: false enable_history: true max_history_turns: 3 semantic_cache_enabled: true +semantic_cache_bi_encoder_threshold: 0.90 +semantic_cache_cross_encoder_threshold: 0.99 diff --git a/src/cache.py b/src/cache.py index 707ffe5e..55c8057f 100644 --- a/src/cache.py +++ b/src/cache.py @@ -13,7 +13,7 @@ from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts, filter_retrieved_chunks -class Cache(ABC): +class BaseResponseCache(ABC): @abstractmethod def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: pass @@ -39,8 +39,8 @@ def normalize_question(self, q: str) -> str: pass -class SemanticCache(Cache): - def __init__(self, bi_encoder_threshold: float = 0.90, cross_encoder_threshold: float = 0.99, max_entries: int = 50): +class SemanticCache(BaseResponseCache): + def __init__(self, bi_encoder_threshold: float, cross_encoder_threshold: float, max_entries: int = 50): self.cache: Dict[str, Deque[Dict[str, Any]]] = {} self.bi_encoder_threshold = bi_encoder_threshold self.cross_encoder_threshold = cross_encoder_threshold @@ -169,7 +169,7 @@ def compute_embedding(self, question: str, retrievers: List[Any], embed_model: s return vec[0] -class NoOpCache(Cache): +class NoOpCache(BaseResponseCache): def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: return None @@ -191,11 +191,14 @@ def normalize_question(self, q: str) -> str: _GLOBAL_SEMANTIC_CACHE: Optional[SemanticCache] = None -def get_cache(cfg: RAGConfig) -> Cache: +def get_cache(cfg: RAGConfig) -> BaseResponseCache: """Return a configured cache layer, either SemanticCache or NoOpCache depending on config.""" global _GLOBAL_SEMANTIC_CACHE if getattr(cfg, 'semantic_cache_enabled', False): if _GLOBAL_SEMANTIC_CACHE is None: - _GLOBAL_SEMANTIC_CACHE = SemanticCache() + _GLOBAL_SEMANTIC_CACHE = SemanticCache( + bi_encoder_threshold=cfg.semantic_cache_bi_encoder_threshold, + cross_encoder_threshold=cfg.semantic_cache_cross_encoder_threshold + ) return _GLOBAL_SEMANTIC_CACHE return NoOpCache() diff --git a/src/config.py b/src/config.py index cd01dc7f..0c00dcdb 100644 --- a/src/config.py +++ b/src/config.py @@ -47,6 +47,8 @@ class RAGConfig: # cache semantic_cache_enabled: bool = False + semantic_cache_bi_encoder_threshold: float = 0.90 + semantic_cache_cross_encoder_threshold: float = 0.99 # conversational memory enable_history: bool = True diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py index e259b8ec..9c95fced 100644 --- a/tests/test_cache_benchmark.py +++ b/tests/test_cache_benchmark.py @@ -40,6 +40,8 @@ def mock_config(): config.disable_chunks = False config.use_golden_chunks = False config.semantic_cache_enabled = True + config.semantic_cache_bi_encoder_threshold = 0.90 + config.semantic_cache_cross_encoder_threshold = 0.99 return config From 8315008b14a4470e267e4b09914b36f2550bba3f Mon Sep 17 00:00:00 2001 From: infinite-void Date: Thu, 9 Apr 2026 08:36:47 -0400 Subject: [PATCH 20/20] fix the enabled flag --- config/config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/config.yaml b/config/config.yaml index 65cde9dc..2e75436a 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -17,7 +17,7 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 -semantic_cache_enabled: true +semantic_cache_enabled: false semantic_cache_bi_encoder_threshold: 0.90 semantic_cache_cross_encoder_threshold: 0.99 enable_topic_extraction: true