Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
2d68a87
add semantic caching with Bi-Direction Encoder and Cross Encoder
sureshkumarsrinath Feb 6, 2026
9cdf2a7
initial cleanup
sureshkumarsrinath Feb 10, 2026
d07a908
print the cache answer without streaming
sureshkumarsrinath Feb 11, 2026
7e5af19
code clean up
sureshkumarsrinath Feb 11, 2026
305b10d
benchmark checkpoint
sureshkumarsrinath Feb 13, 2026
d8ee028
fix benchmark variations
sureshkumarsrinath Mar 6, 2026
f18adf9
merge conflicts
sureshkumarsrinath Mar 6, 2026
e87e664
add flag for semantic caching
sureshkumarsrinath Mar 6, 2026
76e02f5
Fix unit tests for semantic caching
sureshkumarsrinath Mar 6, 2026
d539b0a
resolve review comments
sureshkumarsrinath Mar 11, 2026
8614e7d
use RAGConfig method get_config_state() to get payload for hashing key
Mar 13, 2026
d783bc3
add gpu support try for generator
Mar 13, 2026
d28820c
Merge branch 'main' into semantic-caching
Mar 13, 2026
1f27533
added comments to get_answer() in main so all steps are explained
Mar 13, 2026
f433a41
updated semantic cache testing to include false positive hit checking
Mar 13, 2026
a477f41
fix review comments
sureshkumarsrinath Mar 13, 2026
c8f6fbb
merge main into feature branch for latest updates
sureshkumarsrinath Mar 13, 2026
7faef64
resolve conflicts
sureshkumarsrinath Apr 1, 2026
221e6f3
resolve merge conflicts
sureshkumarsrinath Apr 1, 2026
af7d65e
fix the cache threshold
sureshkumarsrinath Apr 2, 2026
0e27032
fixed interface and add no-op cache
sureshkumarsrinath Apr 2, 2026
9a3f982
use the same cache for every call
sureshkumarsrinath Apr 2, 2026
ce80f6e
clean up PR
sureshkumarsrinath Apr 2, 2026
7a4003a
Merge branch 'main' into semantic-caching
sureshkumarsrinath Apr 3, 2026
bc765ba
fix review comments
sureshkumarsrinath Apr 9, 2026
1ca44f2
resolve conflicts:
sureshkumarsrinath Apr 9, 2026
8315008
fix the enabled flag
sureshkumarsrinath Apr 9, 2026
3743200
update with main
sureshkumarsrinath Apr 16, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ rerank_top_k: 5
use_double_prompt: false
enable_history: true
max_history_turns: 3
enable_topic_extraction: false
semantic_cache_enabled: false
semantic_cache_bi_encoder_threshold: 0.90
semantic_cache_cross_encoder_threshold: 0.99
enable_topic_extraction: false
204 changes: 204 additions & 0 deletions src/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
import argparse
import json
import hashlib
from typing import Dict, Optional, Any, List, Deque
from collections import deque
from abc import ABC, abstractmethod

import numpy as np
from sentence_transformers import CrossEncoder

from src.embedder import SentenceTransformer
from src.config import RAGConfig
from src.retriever import BM25Retriever, FAISSRetriever, IndexKeywordRetriever, load_artifacts, filter_retrieved_chunks


class BaseResponseCache(ABC):
    """Abstract interface that every response-cache implementation must provide.

    Every method is abstract, so a concrete subclass has to implement all six
    before it can be instantiated.
    """

    @abstractmethod
    def lookup(
        self,
        config_key: str,
        query_embedding: np.ndarray,
        current_question: str,
    ) -> Optional[Dict[str, Any]]:
        """Return a cached payload for the question under ``config_key``, or None."""

    @abstractmethod
    def store(
        self,
        config_key: str,
        normalized_question: str,
        question_embedding: Optional[np.ndarray],
        payload: Dict[str, Any],
    ) -> None:
        """Record ``payload`` for the given question under ``config_key``."""

    @abstractmethod
    def clear(self) -> None:
        """Discard whatever the cache currently holds."""

    @abstractmethod
    def make_config_key(
        self,
        cfg: RAGConfig,
        args: argparse.Namespace,
        golden_chunks: Optional[List[str]],
    ) -> str:
        """Build the string key that identifies the active configuration."""

    @abstractmethod
    def compute_embedding(
        self,
        question: str,
        retrievers: List[Any],
        embed_model: str,
    ) -> Optional[np.ndarray]:
        """Return an embedding vector for ``question``, or None if unavailable."""

    @abstractmethod
    def normalize_question(self, q: str) -> str:
        """Return the canonical form of ``q`` used for cache matching."""


class SemanticCache(BaseResponseCache):
    """In-memory response cache matched by question similarity.

    Entries are grouped per configuration key in a deque. A lookup first
    filters entries by the dot product of stored and query embeddings, then
    confirms the best remaining candidate with a cross-encoder score. Each
    bucket drops its oldest entry once it grows beyond ``max_entries``.
    """

    def __init__(self, bi_encoder_threshold: float, cross_encoder_threshold: float, max_entries: int = 50):
        """Set the two similarity thresholds, the bucket capacity, and empty state."""
        # Tunables.
        self.bi_encoder_threshold = bi_encoder_threshold
        self.cross_encoder_threshold = cross_encoder_threshold
        self.max_entries = max_entries
        # Mutable state: cached entries and lazily created models.
        self.cache: Dict[str, Deque[Dict[str, Any]]] = {}
        self.question_embedders: Dict[str, SentenceTransformer] = {}
        self.cross_encoder_model: Optional[CrossEncoder] = None

    def _get_cross_encoder(self) -> CrossEncoder:
        """Return this cache's cross-encoder, constructing it on first use."""
        model = self.cross_encoder_model
        if model is None:
            model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
            self.cross_encoder_model = model
        return model

    def normalize_question(self, q: str) -> str:
        """Lowercase ``q``, trim it, and collapse whitespace runs to single spaces.

        A falsy ``q`` (e.g. None or "") yields the empty string.
        """
        text = q if q else ""
        tokens = text.strip().lower().split()
        return " ".join(tokens)

    def make_config_key(self, cfg: RAGConfig, args: argparse.Namespace, golden_chunks: Optional[List[str]]) -> str:
        """Serialize the active configuration into a JSON string used as cache key.

        The payload comes from ``RAGConfig.get_config_state()``; if that call
        raises, a payload is assembled by hand from ``cfg`` and ``args``. When
        golden chunks are supplied and enabled in ``cfg``, a SHA-256 digest of
        the chunks is added so that different chunk sets get different keys.
        """
        try:
            payload = RAGConfig.get_config_state()
        except Exception:
            # Fallback: rebuild the relevant settings manually; CLI args take
            # precedence over cfg where both exist.
            gen_model = getattr(args, "model_path", None) or cfg.gen_model
            embed_model = cfg.embed_model
            top_k = cfg.top_k
            prompt_mode = getattr(args, "system_prompt_mode", None) or cfg.system_prompt_mode
            payload = {
                "gen_model": gen_model,
                "embed_model": embed_model,
                "top_k": top_k,
                "system_prompt_mode": prompt_mode,
                "ensemble_method": cfg.ensemble_method,
                "ranker_weights": cfg.ranker_weights,
                "use_hyde": cfg.use_hyde,
                "use_indexed_chunks": cfg.use_indexed_chunks,
                "disable_chunks": cfg.disable_chunks,
                "use_golden_chunks": bool(golden_chunks and cfg.use_golden_chunks),
                "index_prefix": getattr(args, "index_prefix", None),
            }

        if golden_chunks and cfg.use_golden_chunks:
            joined_chunks = "||".join(golden_chunks)
            digest = hashlib.sha256(joined_chunks.encode("utf-8")).hexdigest()
            payload["golden_signature"] = digest

        # sort_keys makes the key independent of dict insertion order.
        return json.dumps(payload, sort_keys=True)

    def lookup(self, config_key: str, query_embedding: np.ndarray, current_question: str) -> Optional[Dict[str, Any]]:
        """Return the payload of the best semantically matching entry, or None.

        Returns None when the bucket for ``config_key`` is missing or empty,
        when ``query_embedding`` is None, when no entry passes the bi-encoder
        threshold, or when the best cross-encoder score does not exceed the
        cross-encoder threshold.
        """
        bucket = self.cache.get(config_key)
        if not bucket or query_embedding is None:
            return None

        # Stage 1: cheap embedding filter. The dot product is a cosine
        # similarity only if both vectors are unit-normalized (presumably what
        # normalize=True in compute_embedding provides).
        shortlisted = []
        for item in bucket:
            similarity = np.dot(item["embedding"], query_embedding)
            if similarity > self.bi_encoder_threshold:
                shortlisted.append(item)
        if not shortlisted:
            return None

        # Stage 2: score each (current, cached) question pair and keep the best.
        scorer = self._get_cross_encoder()
        question_pairs = [[current_question, item["question"]] for item in shortlisted]
        pair_scores = scorer.predict(question_pairs, show_progress_bar=False)
        winner = int(np.argmax(pair_scores))

        if pair_scores[winner] > self.cross_encoder_threshold:
            return shortlisted[winner]["payload"]
        return None

    def store(self, config_key: str, normalized_question: str, question_embedding: Optional[np.ndarray], payload: Dict[str, Any]) -> None:
        """Append a question/embedding/payload entry to the bucket for ``config_key``.

        Nothing is stored when ``question_embedding`` is None. The embedding is
        kept as float32. If the bucket then holds more than ``max_entries``
        items, its oldest entry is removed.
        """
        if question_embedding is None:
            return

        bucket = self.cache.setdefault(config_key, deque())
        record = {
            "question": normalized_question,
            "embedding": question_embedding.astype(np.float32),
            "payload": payload,
        }
        bucket.append(record)

        if len(bucket) > self.max_entries:
            bucket.popleft()

    def clear(self) -> None:
        """Remove every bucket; loaded models are kept."""
        self.cache.clear()

    def _get_question_embedder(self, retrievers: List[Any], embed_model: str) -> Optional[SentenceTransformer]:
        """Pick the embedder used to encode questions.

        The first FAISSRetriever in ``retrievers`` supplies its own embedder.
        Otherwise a SentenceTransformer for ``embed_model`` is created once and
        reused; None is returned when ``embed_model`` is falsy.
        """
        faiss_retriever = next(
            (candidate for candidate in retrievers if isinstance(candidate, FAISSRetriever)),
            None,
        )
        if faiss_retriever is not None:
            return faiss_retriever.embedder

        if not embed_model:
            return None

        cached = self.question_embedders.get(embed_model)
        if cached is not None:
            return cached

        fresh = SentenceTransformer(embed_model)
        self.question_embedders[embed_model] = fresh
        return fresh

    def compute_embedding(self, question: str, retrievers: List[Any], embed_model: str) -> Optional[np.ndarray]:
        """Encode ``question`` and return the first row of the result.

        Returns None when no embedder is available or the encoder output is
        empty. The encoder is called with ``normalize=True``.
        """
        embedder = self._get_question_embedder(retrievers, embed_model)
        if not embedder:
            return None

        encoded = embedder.encode([question], batch_size=1, normalize=True, show_progress_bar=False)
        return None if encoded.size == 0 else encoded[0]


class NoOpCache(BaseResponseCache):
    """Cache stand-in that never stores or returns anything.

    Lookups always miss, writes are ignored, and the key, embedding and
    normalization helpers return empty values.
    """

    def lookup(
        self,
        config_key: str,
        query_embedding: np.ndarray,
        current_question: str,
    ) -> Optional[Dict[str, Any]]:
        """Always report a cache miss."""
        return None

    def store(
        self,
        config_key: str,
        normalized_question: str,
        question_embedding: Optional[np.ndarray],
        payload: Dict[str, Any],
    ) -> None:
        """Ignore the entry; nothing is kept."""
        return None

    def clear(self) -> None:
        """Do nothing; there is no state to clear."""
        return None

    def make_config_key(
        self,
        cfg: RAGConfig,
        args: argparse.Namespace,
        golden_chunks: Optional[List[str]],
    ) -> str:
        """Return an empty key regardless of the configuration."""
        return ""

    def compute_embedding(
        self,
        question: str,
        retrievers: List[Any],
        embed_model: str,
    ) -> Optional[np.ndarray]:
        """Return None without computing an embedding."""
        return None

    def normalize_question(self, q: str) -> str:
        """Return an empty string regardless of ``q``."""
        return ""


# Shared SemanticCache instance, created on the first enabled get_cache() call.
_GLOBAL_SEMANTIC_CACHE: Optional[SemanticCache] = None


def get_cache(cfg: RAGConfig) -> BaseResponseCache:
    """Return the cache layer selected by ``cfg``.

    When ``cfg.semantic_cache_enabled`` is missing or falsy, a new NoOpCache is
    returned. Otherwise the module-wide SemanticCache is returned; it is built
    from the thresholds in ``cfg`` on first use, and later calls get that same
    instance whatever thresholds they pass.
    """
    global _GLOBAL_SEMANTIC_CACHE

    if not getattr(cfg, 'semantic_cache_enabled', False):
        return NoOpCache()

    if _GLOBAL_SEMANTIC_CACHE is None:
        _GLOBAL_SEMANTIC_CACHE = SemanticCache(
            bi_encoder_threshold=cfg.semantic_cache_bi_encoder_threshold,
            cross_encoder_threshold=cfg.semantic_cache_cross_encoder_threshold,
        )
    return _GLOBAL_SEMANTIC_CACHE
5 changes: 5 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ class RAGConfig:
hyde_max_tokens: int = 300
use_double_prompt: bool = False

# cache
semantic_cache_enabled: bool = False
semantic_cache_bi_encoder_threshold: float = 0.90
semantic_cache_cross_encoder_threshold: float = 0.99

# conversational memory
enable_history: bool = True
max_history_turns: int = 3
Expand Down
64 changes: 55 additions & 9 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
load_artifacts
)
from src.ranking.reranker import rerank
from src.cache import get_cache

ANSWER_NOT_FOUND = "I'm sorry, but I don't have enough information to answer that question."

Expand Down Expand Up @@ -114,11 +115,30 @@ def get_answer(
sources = artifacts["sources"]
retrievers = artifacts["retrievers"]
ranker = artifacts["ranker"]

# Ensure these locals exist for all control flows to avoid UnboundLocalError
ranked_chunks: List[str] = []
topk_idxs: List[int] = []
scores = []

cache = get_cache(cfg)
normalized_question = cache.normalize_question(question)
config_cache_key = cache.make_config_key(cfg, args, golden_chunks)
question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model)

semantic_hit = cache.lookup(config_cache_key, question_embedding, normalized_question)

# Return cached answer if found
if semantic_hit is not None:

ans = semantic_hit.get("answer", "")

if is_test_mode:
return ans, semantic_hit.get("chunks_info"), semantic_hit.get("hyde_query")
console.print("Using cached answer")
render_final_answer(console, ans)
return ans

# Step 1: Get chunks (golden, retrieved, or none)
chunks_info = None
hyde_query = None
Expand All @@ -135,7 +155,7 @@ def get_answer(
# print(f"Retrieval query: {retrieval_query}")
if cfg.use_hyde:
retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens)

pool_n = max(cfg.num_candidates, cfg.top_k + 10)
raw_scores: Dict[str, Dict[int, float]] = {}
for retriever in retrievers:
Expand Down Expand Up @@ -201,7 +221,6 @@ def get_answer(
system_prompt = args.system_prompt_mode or cfg.system_prompt_mode

use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt

if use_double:
stream_iter = double_answer(
question,
Expand All @@ -220,12 +239,7 @@ def get_answer(
)

if is_test_mode:
# We do not render MD in the test mode
ans = ""
for delta in stream_iter:
ans += delta
ans = dedupe_generated_text(ans)
return ans, chunks_info, hyde_query
ans = dedupe_generated_text("".join(stream_iter))
else:
# Accumulate the full text while rendering incremental Markdown chunks
ans = render_streaming_ans(console, stream_iter)
Expand All @@ -249,7 +263,27 @@ def get_answer(
top_k=len(topk_idxs),
additional_log_info=additional_log_info
)
return ans

# Step 5: Store in semantic cache
cache_payload = {
"answer": ans,
"chunks_info": chunks_info,
"hyde_query": hyde_query,
"chunk_indices": topk_idxs,
}
if question_embedding is None:
question_embedding = cache.compute_embedding(normalized_question, retrievers, cfg.embed_model)
cache.store(
config_cache_key,
normalized_question,
question_embedding,
cache_payload
)

if is_test_mode:
return ans, chunks_info, hyde_query

return ans

def render_streaming_ans(console, stream_iter):
ans = ""
Expand All @@ -266,6 +300,18 @@ def render_streaming_ans(console, stream_iter):
console.print("\n[bold cyan]=== END OF ANSWER ===[/bold cyan]\n")
return ans

def render_final_answer(console, ans):
    """Print a fully generated answer as Markdown, without streaming.

    Used for cache hits, where the complete answer text is already available.
    The answer is printed between a start banner and an end banner.

    Raises:
        ValueError: if ``console`` is falsy.
    """
    if not console:
        raise ValueError("Console must be non null for rendering.")

    start_banner = "\n[bold cyan]==================== START OF ANSWER ===================[/bold cyan]\n"
    end_banner = "\n[bold cyan]===================== END OF ANSWER ====================[/bold cyan]\n"

    console.print(start_banner)
    console.print(Markdown(ans))
    console.print(end_banner)

def get_keywords(question: str) -> list:
"""
Simple keyword extraction from the question.
Expand Down
Loading
Loading