diff --git a/config/config.yaml b/config/config.yaml index 748180d1..389f8e31 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -19,4 +19,7 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 -enable_topic_extraction: false \ No newline at end of file +semantic_cache_enabled: false +semantic_cache_bi_encoder_threshold: 0.90 +semantic_cache_cross_encoder_threshold: 0.99 +enable_topic_extraction: false diff --git a/src/cache.py b/src/cache.py new file mode 100644 index 00000000..55c8057f --- /dev/null +++ b/src/cache.py @@ -0,0 +1,204 @@ +import argparse +import json +import hashlib +from typing import Dict, Optional, Any, List, Deque +from collections import deque +from abc import ABC, abstractmethod + +import numpy as np +from sentence_transformers import CrossEncoder + +from src.embedder import SentenceTransformer +from src.config import RAGConfig +from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts, filter_retrieved_chunks + + +class BaseResponseCache(ABC): + @abstractmethod + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + pass + + @abstractmethod + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + pass + + @abstractmethod + def clear(self) -> None: + pass + + @abstractmethod + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + pass + + @abstractmethod + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: + pass + + @abstractmethod + def normalize_question(self, q: str) -> str: + pass + + +class SemanticCache(BaseResponseCache): + def __init__(self, bi_encoder_threshold: float, cross_encoder_threshold: float, max_entries: int = 50): + self.cache: Dict[str, Deque[Dict[str, Any]]] = {} + self.bi_encoder_threshold = bi_encoder_threshold + self.cross_encoder_threshold = cross_encoder_threshold + self.max_entries = max_entries + self.question_embedders: Dict[str, SentenceTransformer] = {} + self.cross_encoder_model: Optional[CrossEncoder] = None + + def _get_cross_encoder(self) -> CrossEncoder: + """Return a global cross-encoder model instance, initializing if needed.""" + if self.cross_encoder_model is None: + self.cross_encoder_model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') + return self.cross_encoder_model + + def normalize_question(self, q: str) -> str: + """Normalize a question string: lowercase, strip, and collapse spaces.""" + return " ".join((q or "").strip().lower().split()) + + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + """ + Create a unique JSON key for semantic cache based on config, arguments, and optional golden chunks. + """ + try: + payload = RAGConfig.get_config_state() + except Exception: + payload = { + "gen_model": getattr(args, "model_path", None) or cfg.gen_model, + "embed_model": cfg.embed_model, + "top_k": cfg.top_k, + "system_prompt_mode": getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode, + "ensemble_method": cfg.ensemble_method, + "ranker_weights": cfg.ranker_weights, + "use_hyde": cfg.use_hyde, + "use_indexed_chunks": cfg.use_indexed_chunks, + "disable_chunks": cfg.disable_chunks, + "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks), + "index_prefix": getattr(args, "index_prefix", None), + } + + if golden_chunks and cfg.use_golden_chunks: + signature = hashlib.sha256("||".join(golden_chunks).encode("utf-8")).hexdigest() + payload["golden_signature"] = signature + + return json.dumps(payload, sort_keys=True) + + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + """ + Retrieve a cached answer if semantically similar to the current question. + """ + entries = self.cache.get(config_key, []) + if not entries or query_embedding is None: + return None + + # Step 1: Bi-Encoder filter (fast cosine similarity) + candidates = [ + entry for entry in entries + if np.dot(entry["embedding"], query_embedding) > self.bi_encoder_threshold + ] + if not candidates: + return None + + # Step 2: Cross-Encoder verification + ce_model = self._get_cross_encoder() + pairs = [[current_question, c["question"]] for c in candidates] + ce_scores = ce_model.predict(pairs, show_progress_bar=False) + best_idx = int(np.argmax(ce_scores)) + + if ce_scores[best_idx] > self.cross_encoder_threshold: + return candidates[best_idx]["payload"] + return None + + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + """ + Store a question, its embedding, and the generated answer in the semantic cache. + Evict oldest entries if cache exceeds self.max_entries. + """ + if question_embedding is None: + return + + if config_key not in self.cache: + self.cache[config_key] = deque() + entries = self.cache[config_key] + entries.append({ + "question": normalized_question, + "embedding": question_embedding.astype(np.float32), + "payload": payload, + }) + + if len(entries) > self.max_entries: + entries.popleft() + + def clear(self) -> None: + self.cache.clear() + + def _get_question_embedder(self, retrievers: List[Any], embed_model: str) -> Optional[SentenceTransformer]: + """ + Get or initialize a SentenceTransformer for encoding questions. + Prefers the embedder from any FAISSRetriever in the retrievers list. + """ + for retriever in retrievers: + if isinstance(retriever, FAISSRetriever): + return retriever.embedder + + model_path = embed_model + if not model_path: + return None + + embedder = self.question_embedders.get(model_path) + if embedder is None: + embedder = SentenceTransformer(model_path) + self.question_embedders[model_path] = embedder + + return embedder + + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: + """ + Compute a normalized embedding vector for a question using the configured embedder. + """ + embedder = self._get_question_embedder(retrievers, embed_model) + if not embedder: + return None + + vec = embedder.encode([question], batch_size=1, normalize=True, show_progress_bar=False) + if vec.size == 0: + return None + + return vec[0] + + +class NoOpCache(BaseResponseCache): + def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]: + return None + + def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None: + pass + + def clear(self) -> None: + pass + + def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str: + return "" + + def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]: + return None + + def normalize_question(self, q: str) -> str: + return "" + + +_GLOBAL_SEMANTIC_CACHE: Optional[SemanticCache] = None + +def get_cache(cfg: RAGConfig) -> BaseResponseCache: + """Return a configured cache layer, either SemanticCache or NoOpCache depending on config.""" + global _GLOBAL_SEMANTIC_CACHE + if getattr(cfg, 'semantic_cache_enabled', False): + if _GLOBAL_SEMANTIC_CACHE is None: + _GLOBAL_SEMANTIC_CACHE = SemanticCache( + bi_encoder_threshold=cfg.semantic_cache_bi_encoder_threshold, + cross_encoder_threshold=cfg.semantic_cache_cross_encoder_threshold + ) + return _GLOBAL_SEMANTIC_CACHE + return NoOpCache() diff --git a/src/config.py b/src/config.py index 642821a4..2ebd49ac 100644 --- a/src/config.py +++ b/src/config.py @@ -46,6 +46,11 @@ class RAGConfig: hyde_max_tokens: int = 300 use_double_prompt: bool = False + # cache + semantic_cache_enabled: bool = False + semantic_cache_bi_encoder_threshold: float = 0.90 + semantic_cache_cross_encoder_threshold: float = 0.99 + # conversational memory enable_history: bool = True max_history_turns: int = 3 diff --git a/src/main.py b/src/main.py index 64a7da8c..3b165163 100644 --- a/src/main.py +++ b/src/main.py @@ -27,6 +27,7 @@ load_artifacts ) from src.ranking.reranker import rerank +from src.cache import get_cache ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question." @@ -114,11 +115,30 @@ def get_answer( sources = artifacts["sources"] retrievers = artifacts["retrievers"] ranker = artifacts["ranker"] + # Ensure these locals exist for all control flows to avoid UnboundLocalError ranked_chunks: List[str] = [] topk_idxs: List[int] = [] scores = [] + + cache = get_cache(cfg) + normalized_question = cache.normalize_question(question) + config_cache_key = cache.make_config_key(cfg, args, golden_chunks) + question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model) + semantic_hit = cache.lookup(config_cache_key, question_embedding, normalized_question) + + # Return cached answer if found + if semantic_hit is not None: + + ans = semantic_hit.get("answer", "") + + if is_test_mode: + return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query") + console.print("Using cached answer") + render_final_answer(console, ans) + return ans + # Step 1: Get chunks (golden, retrieved, or none) chunks_info = None hyde_query = None @@ -135,7 +155,7 @@ def get_answer( # print(f"Retrieval query: {retrieval_query}") if cfg.use_hyde: retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) - + pool_n = max(cfg.num_candidates, cfg.top_k + 10) raw_scores: Dict[str, Dict[int, float]] = {} for retriever in retrievers: @@ -201,7 +221,6 @@ def get_answer( system_prompt = args.system_prompt_mode or cfg.system_prompt_mode use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt - if use_double: stream_iter = double_answer( question, @@ -220,12 +239,7 @@ def get_answer( ) if is_test_mode: - # We do not render MD in the test mode - ans = "" - for delta in stream_iter: - ans += delta - ans = dedupe_generated_text(ans) - return ans, chunks_info, hyde_query + ans = dedupe_generated_text("".join(stream_iter)) else: # Accumulate the full text while rendering incremental Markdown chunks ans = render_streaming_ans(console, stream_iter) @@ -249,7 +263,27 @@ def get_answer( top_k=len(topk_idxs), additional_log_info=additional_log_info ) - return ans + + # Step 5: Store in semantic cache + cache_payload = { + "answer": ans, + "chunks_info": chunks_info, + "hyde_query": hyde_query, + "chunk_indices": topk_idxs, + } + if question_embedding is None: + question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model) + cache.store( + config_cache_key, + normalized_question, + question_embedding, + cache_payload + ) + + if is_test_mode: + return ans, chunks_info, hyde_query + + return ans def render_streaming_ans(console, stream_iter): ans = "" @@ -266,6 +300,18 @@ def render_streaming_ans(console, stream_iter): console.print("\n[bold cyan]=== END OF ANSWER ===[/bold cyan]\n") return ans +# Fully generated answer without streaming (Usage: cache hits) +def render_final_answer(console, ans): + if not console: + raise ValueError("Console must be non null for rendering.") + console.print( + "\n[bold cyan]==================== START OF ANSWER ===================[/bold cyan]\n" + ) + console.print(Markdown(ans)) + console.print( + "\n[bold cyan]===================== END OF ANSWER ====================[/bold cyan]\n" + ) + def get_keywords(question: str) -> list: """ Simple keyword extraction from the question. diff --git a/tests/cache_benchmark.yaml b/tests/cache_benchmark.yaml new file mode 100644 index 00000000..511f2994 --- /dev/null +++ b/tests/cache_benchmark.yaml @@ -0,0 +1,449 @@ +- id: q1 + question: "What is a database management system?" + variations: + - "Can you explain what a database management system is?" + - "How would you define a database management system?" + - "Please tell me what a database management system is." + - "What exactly is a database management system?" + - "What is the definition of a database management system?" + adversarial_queries: + - "What is a file management system?" # semantic: file vs database + - "What is an operating system?" # semantic: OS vs DBMS + - "What is a spreadsheet application?" # semantic: adjacent tool, different thing + - "What is the management style of a data scientist?" # syntactic: plays on "management" and "data" + - "What does it mean to manage a system database in Windows?" # syntactic: "manage" + "database" in different order/context + +- id: q2 + question: "What is the relational model?" + variations: + - "Can you explain what the relational model is?" + - "How would you define the relational model?" + - "Please tell me what the relational model is." + - "What exactly is the relational model?" + - "What is the definition of the relational model?" + adversarial_queries: + - "What is the hierarchical model?" # semantic: sibling data model + - "What is the network model in databases?" # semantic: sibling data model + - "What is the object-oriented data model?" # semantic: sibling data model + - "What is the relationship between a model and a tokenizer?" # syntactic: relational → relationship, model stays + - "What is a relational conflict in psychology?" # syntactic: relational used in a completely different domain + +- id: q3 + question: "What is SQL?" + variations: + - "Can you explain what SQL is?" + - "How would you define SQL?" + - "Please tell me what SQL is." + - "What exactly is SQL?" + - "What is the definition of SQL?" + adversarial_queries: + - "What is NoSQL?" # semantic: directly contrasted technology + - "What are some alternatives to SQL?" # semantic: adjacent but asking for different things + - "What is the sequel to a popular movie franchise?" # syntactic: SQL is pronounced "sequel" + - "What is a skill queue list in project management?" # syntactic: S-Q-L letters reinterpreted + - "What is MySQL?" # semantic: MySQL is a DBMS, not the language itself + +- id: q4 + question: "What is a primary key?" + variations: + - "Can you explain what a primary key is?" + - "How would you define a primary key?" + - "Please tell me what a primary key is." + - "What exactly is a primary key?" + - "What is the definition of a primary key?" + adversarial_queries: + - "What is a foreign key?" # semantic: sibling key concept + - "What is a candidate key?" # semantic: sibling key concept + - "What is a composite key?" # semantic: sibling key concept + - "What is a master key used for in physical security?" # syntactic: primary/master are synonyms, but in a different domain + - "What is the primary responsibility of a database lock?" # syntactic: "primary" + "key" split apart into different meaning + +- id: q5 + question: "What is a foreign key?" + variations: + - "Can you explain what a foreign key is?" + - "How would you define a foreign key?" + - "Please tell me what a foreign key is." + - "What exactly is a foreign key?" + - "What is the definition of a foreign key?" + adversarial_queries: + - "What is a primary key?" # semantic: sibling key concept + - "What is a surrogate key?" # semantic: sibling key concept + - "What is a join in SQL?" # semantic: joins use FKs but are a different concept + - "What is a key used in foreign language translation?" # syntactic: "foreign" + "key" in a non-database context + - "What is an alien data type in a database column?" # syntactic: alien is a synonym for foreign, reframed + +- id: q6 + question: "What is an Entity-Relationship model?" + variations: + - "Can you explain what an Entity-Relationship model is?" + - "How would you define an Entity-Relationship model?" + - "Please tell me what an Entity-Relationship model is." + - "What exactly is an Entity-Relationship model?" + - "What is the definition of an Entity-Relationship model?" + adversarial_queries: + - "What is the relational model?" # semantic: different model + - "What is a UML class diagram?" # semantic: different modeling notation + - "What is a data flow diagram?" # semantic: different diagram type + - "What is the relationship between an entity and its manager in HR?" # syntactic: entity + relationship in HR context + - "What is entity authentication in network security?" # syntactic: entity reused in a security context + +- id: q7 + question: "What is normalization?" + variations: + - "Can you explain what normalization is?" + - "How would you define normalization?" + - "Please tell me what normalization is." + - "What exactly is normalization?" + - "What is the definition of normalization?" + adversarial_queries: + - "What is denormalization?" # semantic: direct opposite concept + - "What is normalization in machine learning?" # syntactic: same word, entirely different domain + - "What is normalization in signal processing?" # syntactic: same word, entirely different domain + - "What is data cleaning?" # semantic: adjacent preprocessing concept, different meaning + - "What is normal distribution in statistics?" # syntactic: "normal" shared but normalization in stats is not DB normalization + +- id: q8 + question: "What is a transaction?" + variations: + - "Can you explain what a transaction is?" + - "How would you define a transaction?" + - "Please tell me what a transaction is." + - "What exactly is a transaction?" + - "What is the definition of a transaction?" + adversarial_queries: + - "What is a stored procedure?" # semantic: different DB construct + - "What is a SQL query?" # semantic: a query is not a transaction + - "What is a financial transaction in banking?" # syntactic: same word, entirely different domain + - "What is a transaction fee on a credit card?" # syntactic: transaction reused in payments context + - "What is a batch job in data processing?" # semantic: similar in feel but different concept + +- id: q9 + question: "What are ACID properties?" + variations: + - "Can you explain what ACID properties are?" + - "How would you define ACID properties?" + - "Please tell me what ACID properties are." + - "What exactly are ACID properties?" + - "What is the definition of ACID properties?" + adversarial_queries: + - "What are BASE properties in NoSQL databases?" # semantic: directly contrasted acronym + - "What is the CAP theorem?" # semantic: related distributed systems theory + - "What are the chemical properties of an acid?" # syntactic: ACID is also a chemistry term + - "What are SOLID principles in software engineering?" # syntactic: SOLID is another well-known acronym, similar phrasing + - "What are database constraints?" # semantic: constraints enforce rules but are not ACID + +- id: q10 + question: "What is concurrency control?" + variations: + - "Can you explain what concurrency control is?" + - "How would you define concurrency control?" + - "Please tell me what concurrency control is." + - "What exactly is concurrency control?" + - "What is the definition of concurrency control?" + adversarial_queries: + - "What is deadlock?" # semantic: a problem concurrency control solves, not the concept itself + - "What is two-phase locking?" # semantic: a mechanism used in concurrency control, not the concept + - "What is parallel query processing?" # semantic: related to concurrency but different concept + - "What is currency control in international finance?" # syntactic: con-currency → currency, control stays + - "What is access control in operating systems?" # syntactic: "control" shared but access control is not concurrency control + +- id: q11 + question: "What involves indexing in databases?" + variations: + - "Can you explain what involves indexing in databases?" + - "Please describe what involves indexing in databases." + - "What exactly involves indexing in databases?" + - "I would like to know what involves indexing in databases." + - "Could you clarify what involves indexing in databases?" + adversarial_queries: + - "What is query optimization?" # semantic: indexing helps but query opt is broader + - "What is database partitioning?" # semantic: adjacent storage concept + - "What is indexing a book in a library catalog?" # syntactic: indexing in publishing, not databases + - "What is the stock market index?" # syntactic: index in finance domain + - "What is a database schema design?" # semantic: schema design is not indexing + +- id: q12 + question: "What is a B+ tree?" + variations: + - "Can you explain what a B+ tree is?" + - "How would you define a B+ tree?" + - "Please tell me what a B+ tree is." + - "What exactly is a B+ tree?" + - "What is the definition of a B+ tree?" + adversarial_queries: + - "What is a B-tree?" # semantic: sibling structure, different in important ways + - "What is a binary search tree?" # semantic: different tree structure + - "What is an AVL tree?" # semantic: different self-balancing tree + - "What are the health benefits of vitamin B complex?" # syntactic: B+ reads like a grade or vitamin "B plus" + - "What is a family tree in genealogy?" # syntactic: tree as a concept in a different domain + +- id: q13 + question: "What is hashing?" + variations: + - "Can you explain what hashing is?" + - "How would you define hashing?" + - "Please tell me what hashing is." + - "What exactly is hashing?" + - "What is the definition of hashing?" + adversarial_queries: + - "What is encryption?" # semantic: often conflated with hashing but different + - "What is a digital signature?" # semantic: uses hashing but is a different concept + - "What is a hashtag on social media?" # syntactic: hash in a completely different context + - "What is hash brown preparation in cooking?" # syntactic: hash as a culinary term + - "What is checksum verification?" # semantic: similar purpose but different mechanism + +- id: q14 + question: "What is query optimization?" + variations: + - "Can you explain what query optimization is?" + - "How would you define query optimization?" + - "Please tell me what query optimization is." + - "What exactly is query optimization?" + - "What is the definition of query optimization?" + adversarial_queries: + - "What is query parsing?" # semantic: a step before optimization, not optimization itself + - "What is query execution?" # semantic: what happens after optimization + - "What is optimization in machine learning?" # syntactic: optimization as a term in a different domain + - "What is search engine query optimization?" # syntactic: SEO query optimization means something totally different + - "What is database indexing?" # semantic: a tool that aids optimization, not optimization itself + +- id: q15 + question: "What is Big Data?" + variations: + - "Can you explain what Big Data is?" + - "How would you define Big Data?" + - "Please tell me what Big Data is." + - "What exactly is Big Data?" + - "What is the definition of Big Data?" + adversarial_queries: + - "What is a data warehouse?" # semantic: related but a specific storage architecture + - "What is data mining?" # semantic: a technique applied to data, not a paradigm + - "What does it mean to have a big database?" # syntactic: "big" + "database" literally, not the coined term + - "What is a large file in cloud storage?" # syntactic: big data as in physically large files + - "What is cloud computing?" # semantic: often mentioned with Big Data but different concept + +- id: q16 + question: "What is MapReduce?" + variations: + - "Can you explain what MapReduce is?" + - "How would you define MapReduce?" + - "Please tell me what MapReduce is." + - "What exactly is MapReduce?" + - "What is the definition of MapReduce?" + adversarial_queries: + - "What is Apache Spark?" # semantic: successor/alternative, different concept + - "What is a distributed file system?" # semantic: infrastructure MapReduce runs on + - "What is the map data structure in programming?" # syntactic: map() as a programming concept + - "What does it mean to reduce a fraction in mathematics?" # syntactic: reduce as a math operation + - "What is a Google Maps route reduction algorithm?" # syntactic: map + reduce individually repurposed + +- id: q17 + question: "What is NoSQL?" + variations: + - "Can you explain what NoSQL is?" + - "How would you define NoSQL?" + - "Please tell me what NoSQL is." + - "What exactly is NoSQL?" + - "What is the definition of NoSQL?" + adversarial_queries: + - "What is SQL?" # semantic: the thing NoSQL is named after + - "What is NewSQL?" # semantic: similar-sounding but distinct category + - "What is a distributed database?" # semantic: NoSQL is often distributed but they are not the same + - "What is the SQL NOT operator in a query?" # syntactic: No as a negation of SQL + - "What does it mean to say no to a database request?" # syntactic: No + SQL interpreted literally + +- id: q18 + question: "What is data mining?" + variations: + - "Can you explain what data mining is?" + - "How would you define data mining?" + - "Please tell me what data mining is." + - "What exactly is data mining?" + - "What is the definition of data mining?" + adversarial_queries: + - "What is a data warehouse?" # semantic: where data mining data often lives + - "What is machine learning?" # semantic: related field but different concept + - "What is cryptocurrency mining?" # syntactic: mining in the blockchain domain + - "What is coal mining and how does it work?" # syntactic: mining as a physical activity + - "What is ETL in data engineering?" # semantic: moving data, not extracting patterns + +- id: q19 + question: "What is a data warehouse?" + variations: + - "Can you explain what a data warehouse is?" + - "How would you define a data warehouse?" + - "Please tell me what a data warehouse is." + - "What exactly is a data warehouse?" + - "What is the definition of a data warehouse?" + adversarial_queries: + - "What is a data lake?" # semantic: sibling storage concept + - "What is a data mart?" # semantic: a subset of a warehouse, different concept + - "What is an operational database?" # semantic: contrasted with analytical warehouse + - "What is a physical warehouse management system?" # syntactic: warehouse as a building/logistics term + - "What is data storage in a cold storage facility?" # syntactic: warehouse → cold storage, data stays + +- id: q20 + question: "What is distributed database?" + variations: + - "Can you explain what distributed database is?" + - "How would you define distributed database?" + - "Please tell me what distributed database is." + - "What exactly is distributed database?" + - "What is the definition of distributed database?" + adversarial_queries: + - "What is a cloud database?" # semantic: often confused, but different concept + - "What is a parallel database?" # semantic: parallel is not distributed + - "What is a peer-to-peer network?" # semantic: similar topology but different concept + - "What is a distributed team in remote work management?" # syntactic: distributed used in an HR/org context + - "What is network latency in telecommunications?" # semantic: a challenge for distributed DBs, not the concept itself + +- id: q21 + question: "What is database security?" + variations: + - "Can you explain what database security is?" + - "How would you define database security?" + - "Please tell me what database security is." + - "What exactly is database security?" + - "What is the definition of database security?" + adversarial_queries: + - "What is network security?" # semantic: different security domain + - "What is database backup and recovery?" # semantic: backup is operational, not security + - "What is national security policy?" # syntactic: security reused in political context + - "What is a security guard database used for?" # syntactic: "database" + "security" swapped in meaning + - "What is data privacy regulation?" # semantic: related but policy/legal, not technical security + +- id: q22 + question: "What is RAID?" + variations: + - "Can you explain what RAID is?" + - "How would you define RAID?" + - "Please tell me what RAID is." + - "What exactly is RAID?" + - "What is the definition of RAID?" + adversarial_queries: + - "What is database replication?" # semantic: similar redundancy goal, different mechanism + - "What is database backup?" # semantic: related to durability but different concept + - "What is a police raid operation?" # syntactic: RAID as a law enforcement action + - "What is RAID insect spray and how does it work?" # syntactic: RAID as a brand name + - "What is a solid state drive?" # semantic: storage hardware, not a storage strategy + +- id: q23 + question: "What is a view in SQL?" + variations: + - "Can you explain what a view in SQL is?" + - "How would you define a view in SQL?" + - "Please tell me what a view in SQL is." + - "What exactly is a view in SQL?" + - "What is the definition of a view in SQL?" + adversarial_queries: + - "What is a stored procedure in SQL?" # semantic: different SQL construct + - "What is a SQL index?" # semantic: different SQL construct + - "What is a scenic view from a mountain?" # syntactic: view as a visual/landscape concept + - "What is a view in the MVC software architecture pattern?" # syntactic: view as a UI layer in MVC + - "What is a database trigger?" # semantic: different DB automation construct + +- id: q24 + question: "What is a trigger?" + variations: + - "Can you explain what a trigger is?" + - "How would you define a trigger?" + - "Please tell me what a trigger is." + - "What exactly is a trigger?" + - "What is the definition of a trigger?" + adversarial_queries: + - "What is a stored procedure?" # semantic: often confused with triggers + - "What is a database constraint?" # semantic: constraints enforce rules, triggers react to events + - "What is a trigger warning in psychology?" # syntactic: trigger in mental health context + - "What is the trigger mechanism on a firearm?" # syntactic: trigger as a physical gun component + - "What is a user-defined function in SQL?" # semantic: different programmable DB construct + +- id: q25 + question: "What is two-phase locking?" + variations: + - "Can you explain what two-phase locking is?" + - "How would you define two-phase locking?" + - "Please tell me what two-phase locking is." + - "What exactly is two-phase locking?" + - "What is the definition of two-phase locking?" + adversarial_queries: + - "What is two-phase commit?" # syntactic: two-phase is shared, commit is not locking + - "What is optimistic locking?" # semantic: sibling locking strategy + - "What is timestamp-based concurrency control?" # semantic: alternative to locking + - "What is a two-phase electrical power system?" # syntactic: two-phase in an electrical engineering context + - "What is locking in a hairstyling technique?" # syntactic: locking as in dreadlocks + +- id: q26 + question: "What is deadlock?" + variations: + - "Can you explain what deadlock is?" + - "How would you define deadlock?" + - "Please tell me what deadlock is." + - "What exactly is deadlock?" + - "What is the definition of deadlock?" + adversarial_queries: + - "What is livelock?" # semantic: counterpart concept to deadlock + - "What is starvation in database scheduling?" # semantic: related scheduling problem, different cause + - "What is a deadlock in labor union negotiations?" # syntactic: deadlock as a political/negotiation standoff + - "What is a dead lock on a door?" # syntactic: dead + lock as separate physical words + - "What is a race condition in concurrent programming?" # semantic: related concurrency bug, different mechanism + +- id: q27 + question: "What is write-ahead logging?" + variations: + - "Can you explain what write-ahead logging is?" + - "How would you define write-ahead logging?" + - "Please tell me what write-ahead logging is." + - "What exactly is write-ahead logging?" + - "What is the definition of write-ahead logging?" + adversarial_queries: + - "What is database checkpointing?" # semantic: used alongside WAL but different mechanism + - "What is database backup?" # semantic: backup vs logging for recovery + - "What does it mean to write ahead in an academic outline?" # syntactic: write-ahead as a planning metaphor + - "What is activity logging in web server access logs?" # syntactic: logging in a completely different IT context + - "What is undo logging in database recovery?" # semantic: sibling logging strategy with different behavior + +- id: q28 + question: "What is a stored procedure?" + variations: + - "Can you explain what a stored procedure is?" + - "How would you define a stored procedure?" + - "Please tell me what a stored procedure is." + - "What exactly is a stored procedure?" + - "What is the definition of a stored procedure?" + adversarial_queries: + - "What is a database trigger?" # semantic: often confused with stored procedures + - "What is a prepared statement?" # semantic: similar in purpose, different in nature + - "What is a storage procedure in a hospital supply room?" # syntactic: stored + procedure in a medical context + - "What is a procedure in a legal court case?" # syntactic: procedure as a legal term + - "What is a database cursor?" # semantic: different DB programmatic construct + +- id: q29 + question: "What is database recovery?" + variations: + - "Can you explain what database recovery is?" + - "How would you define database recovery?" + - "Please tell me what database recovery is." + - "What exactly is database recovery?" + - "What is the definition of database recovery?" + adversarial_queries: + - "What is database backup?" # semantic: backup enables recovery but is not recovery itself + - "What is RAID?" # semantic: RAID provides redundancy, not recovery protocols + - "What is addiction recovery and how does it work?" # syntactic: recovery in a medical/social context + - "What is economic recovery after a recession?" # syntactic: recovery as a macroeconomic term + - "What is disaster recovery planning in IT?" # semantic: broader IT concept, not database-specific + +- id: q30 + question: "What is XML?" + variations: + - "Can you explain what XML is?" + - "How would you define XML?" + - "Please tell me what XML is." + - "What exactly is XML?" + - "What is the definition of XML?" + adversarial_queries: + - "What is JSON?" # semantic: sibling data serialization format + - "What is HTML?" # semantic: XML sibling, different purpose + - "What is YAML?" # semantic: sibling serialization format + - "What is an X-ray and how does it work in medicine?" # syntactic: X as in XML's first letter, completely different domain + - "What does XML stand for in the military or government?" # syntactic: XML as a potential government acronym \ No newline at end of file diff --git a/tests/test_cache_benchmark.py b/tests/test_cache_benchmark.py new file mode 100644 index 00000000..9c95fced --- /dev/null +++ b/tests/test_cache_benchmark.py @@ -0,0 +1,235 @@ +import pytest +import numpy as np +import yaml +from pathlib import Path +from unittest.mock import MagicMock +from src.config import RAGConfig +from src.cache import get_cache + + +# ----------------------------- +# Data loading +# ----------------------------- + +def load_benchmark_data(): + """Load benchmark questions from YAML.""" + yaml_path = Path(__file__).parent / "cache_benchmark.yaml" + with open(yaml_path, "r") as f: + return yaml.safe_load(f) + + +BENCHMARK_DATA = load_benchmark_data() + + +# ----------------------------- +# Fixtures +# ----------------------------- + +@pytest.fixture +def mock_config(): + """Create a mock RAGConfig for testing.""" + config = MagicMock(spec=RAGConfig) + config.gen_model = "mock-model" + config.embed_model = "models/Qwen3-Embedding-4B-Q5_K_M.gguf" + config.top_k = 5 + config.system_prompt_mode = "baseline" + config.ensemble_method = "rrf" + config.ranker_weights = {"faiss": 0.5, "bm25": 0.5} + config.use_hyde = False + config.use_indexed_chunks = False + config.disable_chunks = False + config.use_golden_chunks = False + config.semantic_cache_enabled = True + config.semantic_cache_bi_encoder_threshold = 0.90 + config.semantic_cache_cross_encoder_threshold = 0.99 + return config + + +# ----------------------------- +# Helpers +# ----------------------------- + +def print_separator(): + print("*" * 65) + + +def print_question_block(question_id, main_question, var_results, adversarial_results): + """ + Print a formatted block for a single question with tick/X results + for both variations and adversarial_queries. + + var_results: list of (question_str, hit: bool) + adversarial_results: list of (question_str, hit: bool) + """ + print_separator() + print(f" [{question_id}] {main_question}") + print() + + print(" Variations (hits are good ✅):") + for q, hit in var_results: + symbol = "✅" if hit else "❌" + print(f" {symbol} {q}") + + print() + print(" Adversarial Queries (hits are bad ⚠️):") + for q, hit in adversarial_results: + symbol = "⚠️ " if hit else "✅" + print(f" {symbol} {q}") + + var_hits = sum(1 for _, h in var_results if h) + adversarial_hits = sum(1 for _, h in adversarial_results if h) + print() + print(f" Accuracy : {var_hits}/{len(var_results)} variations matched") + print(f" False Positives: {adversarial_hits}/{len(adversarial_results)} adversarial queries falsely matched") + + +# ----------------------------- +# Test +# ----------------------------- + +def test_cache_benchmark_comprehensive(mock_config): + """ + Benchmark the semantic cache against 30 questions, each with: + - 5 genuine paraphrase variations (hits expected) + - 5 adversarial queries (hits NOT expected) + + Scores: + Accuracy rate — fraction of genuine variations that got a cache hit. + Higher is better. Target >= 60%. + False Positives rate — fraction of adversarial queries that falsely got a cache hit. + Lower is better. Target <= 0%. + """ + + cache = get_cache(mock_config) + cache.clear() + + args = MagicMock() + args.model_path = None + args.system_prompt_mode = None + args.index_prefix = "test_index" + + cache_key = cache.make_config_key(mock_config, args, None) + embed_model_name = mock_config.embed_model + + total_var_hits = 0 + total_variations = 0 + total_adversarial_hits = 0 + total_adversarial_queries = 0 + + accuracy_failures = [] + false_positive_failures = [] + + print(f"\n{'*' * 65}") + print(f" SEMANTIC CACHE BENCHMARK") + print(f" Embedding model : {embed_model_name}") + print(f" Questions : {len(BENCHMARK_DATA)}") + print(f" Variations each : 5 genuine + 5 adversarial") + print(f"{'*' * 65}") + + for entry in BENCHMARK_DATA: + question_id = entry["id"] + main_question = entry["question"] + variations = entry.get("variations", []) + adversarial_vars = entry.get("adversarial_queries", []) + + # --- Seed the cache with the canonical question --- + normalized_main = cache.normalize_question(main_question) + embedding_main = cache.compute_embedding(normalized_main, [], embed_model_name) + assert embedding_main is not None, ( + f"Failed to compute embedding for {question_id}: '{main_question}'" + ) + + payload = { + "answer": f"Cached answer for {question_id}", + "chunks_info": [], + "hyde_query": None, + "chunk_indices": [], + } + cache.store(cache_key, normalized_main, embedding_main, payload) + + # --- Test genuine variations --- + var_results = [] + for var_q in variations: + normalized_var = cache.normalize_question(var_q) + embedding_var = cache.compute_embedding(normalized_var, [], embed_model_name) + hit = cache.lookup(cache_key, embedding_var, normalized_var) is not None + var_results.append((var_q, hit)) + if not hit: + accuracy_failures.append(f"[{question_id}] Missed variation : '{var_q}'") + + # --- Test adversarial queries --- + adversarial_results = [] + for adversarial_q in adversarial_vars: + normalized_adversarial = cache.normalize_question(adversarial_q) + embedding_adversarial = cache.compute_embedding(normalized_adversarial, [], embed_model_name) + + payload_hit = cache.lookup(cache_key, embedding_adversarial, normalized_adversarial) + # A false hit is when the cache incorrectly returns the CURRENT question's answer + # for a query that is semantically different. + hit = payload_hit is not None and payload_hit.get("answer") == f"Cached answer for {question_id}" + adversarial_results.append((adversarial_q, hit)) + if hit: + false_positive_failures.append(f"[{question_id}] False hit on adversarial: '{adversarial_q}'") + + # --- Accumulate totals --- + total_var_hits += sum(1 for _, h in var_results if h) + total_variations += len(var_results) + total_adversarial_hits += sum(1 for _, h in adversarial_results if h) + total_adversarial_queries += len(adversarial_results) + + # --- Print per-question block --- + print_question_block(question_id, main_question, var_results, adversarial_results) + + # --- Final summary --- + accuracy_rate = total_var_hits / total_variations if total_variations > 0 else 0.0 + false_positive_rate = total_adversarial_hits / total_adversarial_queries if total_adversarial_queries > 0 else 0.0 + + print_separator() + print() + print(f" {'FINAL BENCHMARK RESULTS':^61}") + print() + print(f" {'Metric':<35} {'Score':>10} {'Count'}") + print(f" {'-'*60}") + print(f" {'Accuracy Rate (higher is better)':<35} {accuracy_rate:>9.1%} {total_var_hits}/{total_variations}") + print(f" {'False Positive Rate (lower is better)':<35} {false_positive_rate:>9.1%} {total_adversarial_hits}/{total_adversarial_queries}") + print() + print(" What these scores mean:") + print() + print(" Accuracy Rate — measures how often the cache correctly") + print(" recognises a genuine paraphrase of a cached question.") + print(" A high accuracy rate means users asking the same thing") + print(" in different words will get a fast cached response.") + print(" Target: >= 60%") + print() + print(" False Positive Rate — measures how often the cache is fooled") + print(" into returning an answer for a semantically or") + print(" syntactically similar but DIFFERENT question. A false") + print(" hit here means a user gets the wrong cached answer.") + print(" Lower is strictly better. Target: <= 0%") + + if accuracy_failures: + print() + print(f" Accuracy misses ({len(accuracy_failures)} total, showing first 10):") + for msg in accuracy_failures[:10]: + print(f" ❌ {msg}") + + if false_positive_failures: + print() + print(f" False Positive hits ({len(false_positive_failures)} total, showing first 10):") + for msg in false_positive_failures[:10]: + print(f" ⚠️ {msg}") + + print() + print_separator() + + # --- Assertions --- + assert accuracy_rate >= 0.80, ( + f"Accuracy rate {accuracy_rate:.1%} is below the 80% target. " + f"The cache is missing too many genuine paraphrases.\n" + f"First 10 misses: {accuracy_failures[:10]}" + ) + assert false_positive_rate <= 0.05, ( + f"False Positives rate {false_positive_rate:.1%} exceeds the 5% target. " + f"The cache is returning false hits for adversarial queries.\n" + f"First 10 leaks: {false_positive_failures[:10]}" + ) \ No newline at end of file