From 02da156aa573803a4e5e56220ff8f9cebfa51ee8 Mon Sep 17 00:00:00 2001 From: ksuleimanova Date: Tue, 14 Apr 2026 20:49:13 -0400 Subject: [PATCH 1/3] L1 cache --- src/config.py | 5 ++ src/l1_cache.py | 160 ++++++++++++++++++++++++++++++++++++++++++++++++ src/main.py | 86 +++++++++++++++++++++----- 3 files changed, 234 insertions(+), 17 deletions(-) create mode 100644 src/l1_cache.py diff --git a/src/config.py b/src/config.py index 35811fc6..89e3b5ab 100644 --- a/src/config.py +++ b/src/config.py @@ -29,6 +29,9 @@ class RAGConfig: ) rerank_mode: str = "" rerank_top_k: int = 5 + enable_l1_cache: bool = True + l1_cache_max_entries: int = 256 + l1_cache_ttl_seconds: int = 600 # generation max_gen_tokens: int = 400 @@ -69,6 +72,8 @@ def __post_init__(self): """Validation logic runs automatically after initialization.""" assert self.top_k > 0, "top_k must be > 0" assert self.num_candidates >= self.top_k, "num_candidates must be >= top_k" + assert self.l1_cache_max_entries > 0, "l1_cache_max_entries must be > 0" + assert self.l1_cache_ttl_seconds > 0, "l1_cache_ttl_seconds must be > 0" assert self.ensemble_method.lower() in {"linear", "weighted", "rrf"} assert self.embedding_model_context_window > 0, "embedding_model_context_window must be > 0" if self.ensemble_method.lower() in {"linear", "weighted"}: diff --git a/src/l1_cache.py b/src/l1_cache.py new file mode 100644 index 00000000..a7baf860 --- /dev/null +++ b/src/l1_cache.py @@ -0,0 +1,160 @@ +""" +In-memory L1 cache for retrieval + ranking outputs. + +Key: query embedding hash (from normalized query text) +Value: normalized query text, top chunk ids, top chunk scores, and LRU/TTL metadata +""" + +from __future__ import annotations + +import hashlib +import re +import time +from dataclasses import dataclass +from collections import OrderedDict +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from src.embedder import CachedEmbedder + + +_EMBEDDER_CACHE: Dict[Tuple[str, int], CachedEmbedder] = {} + + +def _get_embedder(model_path: str, context_window: int) -> CachedEmbedder: + key = (model_path, int(context_window)) + if key not in _EMBEDDER_CACHE: + _EMBEDDER_CACHE[key] = CachedEmbedder(model_path, n_ctx=int(context_window)) + return _EMBEDDER_CACHE[key] + + +def normalize_query_text(query: str) -> str: + """Normalize query text for stable embedding hash keys.""" + return re.sub(r"\s+", " ", query.strip().lower()) + + +@dataclass +class L1CacheEntry: + normalized_query_text: str + top_chunk_ids: List[int] + top_chunk_scores: List[float] + params_signature: str + created_at: float + expires_at: float + last_access_at: float + access_count: int + + +class L1RetrievalCache: + def __init__(self, max_entries: int = 256, ttl_seconds: int = 600): + self.max_entries = max(1, int(max_entries)) + self.ttl_seconds = max(1, int(ttl_seconds)) + self._entries: "OrderedDict[str, L1CacheEntry]" = OrderedDict() + + def _now(self) -> float: + return time.time() + + def _is_expired(self, entry: L1CacheEntry, now: float) -> bool: + return entry.expires_at <= now + + def _evict_expired(self, now: float) -> None: + expired = [k for k, v in self._entries.items() if self._is_expired(v, now)] + for key in expired: + self._entries.pop(key, None) + + def _evict_lru_if_needed(self) -> None: + while len(self._entries) > self.max_entries: + self._entries.popitem(last=False) + + def _params_signature(params: Dict[str, Any]) -> str: + ordered = sorted(params.items(), key=lambda kv: kv[0]) + return hashlib.sha256(repr(ordered).encode("utf-8")).hexdigest() + + def _make_query_embedding_hash( + self, + normalized_query_text: str, + embed_model: str, + embedding_context_window: int, + ) -> str: + embedder = _get_embedder(embed_model, embedding_context_window) + vec = embedder.encode([normalized_query_text], normalize=True).astype(np.float32) + return hashlib.sha256(vec.tobytes()).hexdigest() + + def get( + self, + query: str, + embed_model: str, + embedding_context_window: int, + params: Dict[str, Any], + ) -> Optional[L1CacheEntry]: + now = self._now() + self._evict_expired(now) + + normalized_query_text = normalize_query_text(query) + key = self._make_query_embedding_hash( + normalized_query_text, + embed_model, + embedding_context_window, + ) + + entry = self._entries.get(key) + if entry is None: + return None + + expected_sig = self._params_signature(params) + if entry.params_signature != expected_sig: + self._entries.pop(key, None) + return None + + if self._is_expired(entry, now): + self._entries.pop(key, None) + return None + + entry.last_access_at = now + entry.access_count += 1 + self._entries.move_to_end(key) + return entry + + def set( + self, + query: str, + embed_model: str, + embedding_context_window: int, + top_chunk_ids: List[int], + top_chunk_scores: List[float], + params: Dict[str, Any], + ) -> L1CacheEntry: + now = self._now() + self._evict_expired(now) + + normalized_query_text = normalize_query_text(query) + key = self._make_query_embedding_hash( + normalized_query_text, + embed_model, + embedding_context_window, + ) + + entry = L1CacheEntry( + normalized_query_text=normalized_query_text, + top_chunk_ids=[int(i) for i in top_chunk_ids], + top_chunk_scores=[float(s) for s in top_chunk_scores], + params_signature=self._params_signature(params), + created_at=now, + expires_at=now + self.ttl_seconds, + last_access_at=now, + access_count=1, + ) + + if key in self._entries: + self._entries.pop(key, None) + self._entries[key] = entry + self._evict_lru_if_needed() + return entry + + def stats(self) -> Dict[str, int]: + return { + "entries": len(self._entries), + "max_entries": self.max_entries, + "ttl_seconds": self.ttl_seconds, + } diff --git a/src/main.py b/src/main.py index 64a7da8c..73d0783d 100644 --- a/src/main.py +++ b/src/main.py @@ -12,6 +12,7 @@ from rich.markdown import Markdown from src.config import RAGConfig +from src.l1_cache import L1RetrievalCache from src.generator import answer, double_answer, dedupe_generated_text from src.index_builder import build_index from src.instrumentation.logging import get_logger @@ -122,6 +123,7 @@ def get_answer( # Step 1: Get chunks (golden, retrieved, or none) chunks_info = None hyde_query = None + cache_meta: Dict[str, Any] = {} if golden_chunks and cfg.use_golden_chunks: # Use provided golden chunks ranked_chunks = golden_chunks @@ -136,21 +138,60 @@ def get_answer( if cfg.use_hyde: retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) - pool_n = max(cfg.num_candidates, cfg.top_k + 10) - raw_scores: Dict[str, Dict[int, float]] = {} - for retriever in retrievers: - # print(f"Getting scores from retriever: {retriever.name}...") - raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks) - # TODO: Fix retrieval logging. - - # print("Raw scores from retrievers:") - # for retriever_name, score_dict in raw_scores.items(): - # print(f" {retriever_name}: {list(score_dict.values())}") - # Step 2: Ranking - ordered, scores = ranker.rank(raw_scores=raw_scores) - # print(f"Ordered candidate indices after ranking: {ordered[:cfg.top_k]}") - # print(f"Corresponding scores: {scores[:cfg.top_k]}") - topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered) + l1_cache: Optional[L1RetrievalCache] = artifacts.get("l1_cache") + retriever_names = [r.name for r in retrievers] + retrieval_params = { + "top_k": cfg.top_k, + "num_candidates": cfg.num_candidates, + "ensemble_method": cfg.ensemble_method, + "rrf_k": cfg.rrf_k, + "ranker_weights": tuple(sorted(cfg.ranker_weights.items())), + "retrievers": tuple(sorted(retriever_names)), + "use_hyde": cfg.use_hyde, + "retrieval_query": retrieval_query, + } + + cache_entry = None + if cfg.enable_l1_cache and l1_cache is not None: + cache_entry = l1_cache.get( + query=retrieval_query, + embed_model=cfg.embed_model, + embedding_context_window=cfg.embedding_model_context_window, + params=retrieval_params, + ) + + if cache_entry is not None: + topk_idxs = cache_entry.top_chunk_ids + scores = cache_entry.top_chunk_scores + cache_meta = { + "l1_cache": { + "hit": True, + "normalized_query_text": cache_entry.normalized_query_text, + "last_access_at": cache_entry.last_access_at, + "expires_at": cache_entry.expires_at, + "access_count": cache_entry.access_count, + } + } + else: + pool_n = max(cfg.num_candidates, cfg.top_k + 10) + raw_scores: Dict[str, Dict[int, float]] = {} + for retriever in retrievers: + raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks) + + ordered, scores = ranker.rank(raw_scores=raw_scores) + topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered) + + if cfg.enable_l1_cache and l1_cache is not None: + l1_cache.set( + query=retrieval_query, + embed_model=cfg.embed_model, + embedding_context_window=cfg.embedding_model_context_window, + top_chunk_ids=topk_idxs, + top_chunk_scores=scores[:len(topk_idxs)], + params=retrieval_params, + ) + cache_meta = {"l1_cache": {"hit": False}} + ranked_chunks = [chunks[i] for i in topk_idxs] # print(f"Top-{cfg.top_k} chunk indices after filtering: {topk_idxs}") # print("Len Ranked chunks:", len(ranked_chunks)) @@ -247,7 +288,7 @@ def get_answer( page_map=page_nums, full_response=ans, top_k=len(topk_idxs), - additional_log_info=additional_log_info + additional_log_info={**(additional_log_info or {}), **cache_meta} ) return ans @@ -292,8 +333,19 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): retrievers.append(IndexKeywordRetriever(cfg.extracted_index_path, cfg.page_to_chunk_map_path)) ranker = EnsembleRanker(ensemble_method=cfg.ensemble_method, weights=cfg.ranker_weights, rrf_k=int(cfg.rrf_k)) + l1_cache = L1RetrievalCache( + max_entries=cfg.l1_cache_max_entries, + ttl_seconds=cfg.l1_cache_ttl_seconds, + ) print("Loaded retrievers and initialized ranker.") - artifacts = {"chunks": chunks, "sources": sources, "retrievers": retrievers, "ranker": ranker, "meta": meta} + artifacts = { + "chunks": chunks, + "sources": sources, + "retrievers": retrievers, + "ranker": ranker, + "meta": meta, + "l1_cache": l1_cache, + } except Exception as e: print(f"ERROR: {e}. Run 'index' mode first.") sys.exit(1) From 9ece70b94f4040f766fca0bf273998cf54375432 Mon Sep 17 00:00:00 2001 From: ksuleimanova Date: Tue, 14 Apr 2026 21:27:33 -0400 Subject: [PATCH 2/3] Changes to api server --- src/api_server.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/src/api_server.py b/src/api_server.py index 5e9cf243..33ce2316 100644 --- a/src/api_server.py +++ b/src/api_server.py @@ -26,6 +26,7 @@ from src.config import RAGConfig from src.generator import answer +from src.l1_cache import L1RetrievalCache from src.feedback_store import ( init_feedback_db, save_answer, @@ -49,6 +50,7 @@ _config: Optional[RAGConfig] = None _logger = None _topic_extractor: Optional[TopicExtractor] = None +_l1_cache: Optional[L1RetrievalCache] = None class SourceItem(BaseModel): @@ -149,6 +151,38 @@ def _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums, def _retrieve_and_rank(query: str, top_k: Optional[int] = None): chunks = _artifacts["chunks"] effective_top_k = top_k if top_k is not None else _config.top_k + + # L1 cache lookup (additive path; falls back to existing retrieval flow on miss) + cache_entry = None + if getattr(_config, "enable_l1_cache", False) and _l1_cache is not None: + retriever_names = [r.name for r in _retrievers] + retrieval_params = { + "top_k": effective_top_k, + "num_candidates": _config.num_candidates, + "ensemble_method": _config.ensemble_method, + "rrf_k": _config.rrf_k, + "ranker_weights": tuple(sorted(_config.ranker_weights.items())), + "retrievers": tuple(sorted(retriever_names)), + "query": query, + } + cache_entry = _l1_cache.get( + query=query, + embed_model=_config.embed_model, + embedding_context_window=getattr(_config, "embedding_model_context_window", 4096), + params=retrieval_params, + ) + + if cache_entry is not None: + ordered_ids = cache_entry.top_chunk_ids + ordered_scores = cache_entry.top_chunk_scores + if top_k is not None: + ordered_ids = ordered_ids[:top_k] + ordered_scores = ordered_scores[:top_k] + else: + ordered_ids = ordered_ids[:_config.top_k] + ordered_scores = ordered_scores[:_config.top_k] + return ordered_ids, ordered_scores + pool_n = max(_config.num_candidates, effective_top_k + 10) raw_scores: Dict[str, Dict[int, float]] = {} @@ -157,6 +191,27 @@ def _retrieve_and_rank(query: str, top_k: Optional[int] = None): ordered_ids, ordered_scores = _ranker.rank(raw_scores=raw_scores) + # L1 cache write-back after successful retrieval + ranking. + if getattr(_config, "enable_l1_cache", False) and _l1_cache is not None: + retriever_names = [r.name for r in _retrievers] + retrieval_params = { + "top_k": effective_top_k, + "num_candidates": _config.num_candidates, + "ensemble_method": _config.ensemble_method, + "rrf_k": _config.rrf_k, + "ranker_weights": tuple(sorted(_config.ranker_weights.items())), + "retrievers": tuple(sorted(retriever_names)), + "query": query, + } + _l1_cache.set( + query=query, + embed_model=_config.embed_model, + embedding_context_window=getattr(_config, "embedding_model_context_window", 4096), + top_chunk_ids=ordered_ids[:effective_top_k], + top_chunk_scores=ordered_scores[:effective_top_k], + params=retrieval_params, + ) + if top_k is not None: ordered_ids = ordered_ids[:top_k] ordered_scores = ordered_scores[:top_k] @@ -169,7 +224,7 @@ def _retrieve_and_rank(query: str, top_k: Optional[int] = None): @asynccontextmanager async def lifespan(app: FastAPI): """Initialize artifacts on startup.""" - global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor + global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor, _l1_cache config_path = _resolve_config_path() if not config_path.exists(): @@ -208,6 +263,15 @@ async def lifespan(app: FastAPI): rrf_k=int(_config.rrf_k), ) + # Initialize L1 cache for retrieval/ranking outputs. + if getattr(_config, "enable_l1_cache", False): + _l1_cache = L1RetrievalCache( + max_entries=getattr(_config, "l1_cache_max_entries", 256), + ttl_seconds=getattr(_config, "l1_cache_ttl_seconds", 600), + ) + else: + _l1_cache = None + init_feedback_db() if _config.enable_topic_extraction: _topic_extractor = TopicExtractor( From 8c8314b5752ca32c2717be1fd8575dc7fe691938 Mon Sep 17 00:00:00 2001 From: ksuleimanova Date: Sat, 25 Apr 2026 18:08:16 -0400 Subject: [PATCH 3/3] L2 cache --- config/config.yaml | 5 +- src/api_server.py | 112 +++++++++++++++++++++++++++++++++++---------- src/config.py | 5 ++ src/embedder.py | 29 ++++++++---- src/l2_cache.py | 105 ++++++++++++++++++++++++++++++++++++++++++ src/main.py | 59 ++++++++++++++++++++++-- 6 files changed, 278 insertions(+), 37 deletions(-) create mode 100644 src/l2_cache.py diff --git a/config/config.yaml b/config/config.yaml index b646d2b8..c4963a4f 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -18,4 +18,7 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 -enable_topic_extraction: false \ No newline at end of file +enable_topic_extraction: false +enable_l2_cache: true +l2_cache_max_entries: 256 +l2_cache_ttl_seconds: 600 \ No newline at end of file diff --git a/src/api_server.py b/src/api_server.py index 33ce2316..7551a6db 100644 --- a/src/api_server.py +++ b/src/api_server.py @@ -27,6 +27,7 @@ from src.config import RAGConfig from src.generator import answer from src.l1_cache import L1RetrievalCache +from src.l2_cache import L2AnswerCache from src.feedback_store import ( init_feedback_db, save_answer, @@ -51,6 +52,7 @@ _logger = None _topic_extractor: Optional[TopicExtractor] = None _l1_cache: Optional[L1RetrievalCache] = None +_l2_cache: Optional[L2AnswerCache] = None class SourceItem(BaseModel): @@ -221,10 +223,26 @@ def _retrieve_and_rank(query: str, top_k: Optional[int] = None): return ordered_ids, ordered_scores + +def _l2_generation_params( + prompt_type: str, + max_chunks: int, + temperature: float, + enable_chunks: bool, +) -> Dict[str, object]: + return { + "gen_model": _config.gen_model, + "system_prompt_mode": prompt_type, + "max_gen_tokens": _config.max_gen_tokens, + "max_chunks": int(max_chunks), + "temperature": float(temperature), + "enable_chunks": bool(enable_chunks), + } + @asynccontextmanager async def lifespan(app: FastAPI): """Initialize artifacts on startup.""" - global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor, _l1_cache + global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor, _l1_cache, _l2_cache config_path = _resolve_config_path() if not config_path.exists(): @@ -272,6 +290,14 @@ async def lifespan(app: FastAPI): else: _l1_cache = None + if getattr(_config, "enable_l2_cache", False): + _l2_cache = L2AnswerCache( + max_entries=getattr(_config, "l2_cache_max_entries", 256), + ttl_seconds=getattr(_config, "l2_cache_ttl_seconds", 600), + ) + else: + _l2_cache = None + init_feedback_db() if _config.enable_topic_extraction: _topic_extractor = TopicExtractor( @@ -441,9 +467,22 @@ async def chat_stream(request: ChatRequest): chunks = _artifacts["chunks"] sources = _artifacts["sources"] + + l2_params = _l2_generation_params( + prompt_type=prompt_type, + max_chunks=max_chunks, + temperature=temperature, + enable_chunks=enable_chunks, + ) + l2_entry = None + if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None: + l2_entry = _l2_cache.get(request.query, params=l2_params) - if disable_chunks: + if l2_entry is not None: + ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {} + elif disable_chunks: ranked_chunks, topk_idxs = [], [] + ordered_ranked_scores = {} else: topk_idxs, ordered_ranked_scores = _retrieve_and_rank(request.query, top_k=max_chunks) topk_idxs = [int(i) for i in topk_idxs] @@ -474,12 +513,20 @@ async def event_generator(): yield f"data: {json.dumps({'type': 'sources', 'content': [s.dict() for s in sources_used]})}\n\n" yield f"data: {json.dumps({'type': 'chunks_by_page', 'content': chunks_by_page})}\n\n" - # Stream generation token by token - for delta in answer(request.query, ranked_chunks, _config.gen_model, - _config.max_gen_tokens, system_prompt_mode=prompt_type, temperature=temperature): - if delta: - full_response_accumulator.append(delta) # Capture for log - yield f"data: {json.dumps({'type': 'token', 'content': delta})}\n\n" + if l2_entry is not None: + full_response_accumulator = [l2_entry.answer_text] + yield f"data: {json.dumps({'type': 'token', 'content': l2_entry.answer_text})}\n\n" + else: + # Stream generation token by token + for delta in answer(request.query, ranked_chunks, _config.gen_model, + _config.max_gen_tokens, system_prompt_mode=prompt_type, temperature=temperature): + if delta: + full_response_accumulator.append(delta) # Capture for log + yield f"data: {json.dumps({'type': 'token', 'content': delta})}\n\n" + + final_answer = "".join(full_response_accumulator) + if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None and final_answer.strip(): + _l2_cache.set(request.query, final_answer, params=l2_params) if _logger: success_log = _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums, full_response_accumulator, request, @@ -562,10 +609,24 @@ async def chat(request: ChatRequest): chunks = _artifacts["chunks"] sources = _artifacts["sources"] + l2_params = _l2_generation_params( + prompt_type=prompt_type, + max_chunks=max_chunks, + temperature=temperature, + enable_chunks=enable_chunks, + ) try: + # 2. L2 Cache fast path for exact query + generation params + l2_entry = None + if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None: + l2_entry = _l2_cache.get(request.query, params=l2_params) + # 2. Retrieval & Ranking (SAFE against mocked None return) - if disable_chunks: + if l2_entry is not None: + ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {} + answer_text = l2_entry.answer_text + elif disable_chunks: ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {} else: retrieval_result = _retrieve_and_rank( @@ -591,22 +652,25 @@ async def chat(request: ChatRequest): raise HTTPException(status_code=500, detail="Model path not configured.") # 3. Full Generation - try: - answer_text = "".join( - answer( - request.query, - ranked_chunks, - _config.gen_model, - _config.max_gen_tokens, - system_prompt_mode=prompt_type, - temperature=temperature, + if l2_entry is None: + try: + answer_text = "".join( + answer( + request.query, + ranked_chunks, + _config.gen_model, + _config.max_gen_tokens, + system_prompt_mode=prompt_type, + temperature=temperature, + ) + ) + if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None and answer_text.strip(): + _l2_cache.set(request.query, answer_text, params=l2_params) + except Exception as gen_error: + print(f"Generation failed: {str(gen_error)}") + answer_text = ( + "I'm sorry, but I couldn't generate a response due to an internal error." ) - ) - except Exception as gen_error: - print(f"Generation failed: {str(gen_error)}") - answer_text = ( - "I'm sorry, but I couldn't generate a response due to an internal error." - ) # 4. Post-processing (Metadata & Pages) page_nums = get_page_numbers(topk_idxs, _artifacts["meta"]) or {} diff --git a/src/config.py b/src/config.py index 89e3b5ab..621b6f0b 100644 --- a/src/config.py +++ b/src/config.py @@ -32,6 +32,9 @@ class RAGConfig: enable_l1_cache: bool = True l1_cache_max_entries: int = 256 l1_cache_ttl_seconds: int = 600 + enable_l2_cache: bool = True + l2_cache_max_entries: int = 256 + l2_cache_ttl_seconds: int = 600 # generation max_gen_tokens: int = 400 @@ -74,6 +77,8 @@ def __post_init__(self): assert self.num_candidates >= self.top_k, "num_candidates must be >= top_k" assert self.l1_cache_max_entries > 0, "l1_cache_max_entries must be > 0" assert self.l1_cache_ttl_seconds > 0, "l1_cache_ttl_seconds must be > 0" + assert self.l2_cache_max_entries > 0, "l2_cache_max_entries must be > 0" + assert self.l2_cache_ttl_seconds > 0, "l2_cache_ttl_seconds must be > 0" assert self.ensemble_method.lower() in {"linear", "weighted", "rrf"} assert self.embedding_model_context_window > 0, "embedding_model_context_window must be > 0" if self.ensemble_method.lower() in {"linear", "weighted"}: diff --git a/src/embedder.py b/src/embedder.py index 1b95f532..2fc0bec4 100644 --- a/src/embedder.py +++ b/src/embedder.py @@ -60,15 +60,26 @@ def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = None): self.model_path = model_path self.n_ctx = n_ctx - self.model = Llama( - model_path=model_path, - n_ctx=n_ctx, - n_threads=n_threads, - embedding=True, - verbose=False, - use_mmap=True, - n_gpu_layers=-1, - ) + try: + self.model = Llama( + model_path=model_path, + n_ctx=n_ctx, + n_threads=n_threads, + embedding=True, + verbose=False, + use_mmap=True, + n_gpu_layers=-1, + ) + except Exception as e: + print(f"Error loading embedding model from {model_path} on GPU: {e}") + self.model = Llama( + model_path=model_path, + n_ctx=n_ctx, + n_threads=n_threads, + embedding=True, + verbose=False, + use_mmap=True, + ) self._embedding_dimension = None # Warm up — also caches embedding dimension diff --git a/src/l2_cache.py b/src/l2_cache.py new file mode 100644 index 00000000..cccfaba4 --- /dev/null +++ b/src/l2_cache.py @@ -0,0 +1,105 @@ +""" +In-memory L2 cache for final generated answers. + +Key: normalized query text +Value: answer text and LRU/TTL metadata +""" + +from __future__ import annotations + +import hashlib +import time +from collections import OrderedDict +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from src.l1_cache import normalize_query_text + + +@dataclass +class L2CacheEntry: + normalized_query_text: str + answer_text: str + params_signature: str + created_at: float + expires_at: float + last_access_at: float + access_count: int + + +class L2AnswerCache: + def __init__(self, max_entries: int = 256, ttl_seconds: int = 600): + self.max_entries = max(1, int(max_entries)) + self.ttl_seconds = max(1, int(ttl_seconds)) + self._entries: "OrderedDict[str, L2CacheEntry]" = OrderedDict() + + def _now(self) -> float: + return time.time() + + def _is_expired(self, entry: L2CacheEntry, now: float) -> bool: + return entry.expires_at <= now + + def _evict_expired(self, now: float) -> None: + expired = [k for k, v in self._entries.items() if self._is_expired(v, now)] + for key in expired: + self._entries.pop(key, None) + + def _evict_lru_if_needed(self) -> None: + while len(self._entries) > self.max_entries: + self._entries.popitem(last=False) + + @staticmethod + def _params_signature(params: Dict[str, Any]) -> str: + ordered = sorted(params.items(), key=lambda kv: kv[0]) + return hashlib.sha256(repr(ordered).encode("utf-8")).hexdigest() + + def get(self, query: str, params: Dict[str, Any]) -> Optional[L2CacheEntry]: + now = self._now() + self._evict_expired(now) + + key = normalize_query_text(query) + entry = self._entries.get(key) + if entry is None: + return None + + expected_sig = self._params_signature(params) + if entry.params_signature != expected_sig: + self._entries.pop(key, None) + return None + + if self._is_expired(entry, now): + self._entries.pop(key, None) + return None + + entry.last_access_at = now + entry.access_count += 1 + self._entries.move_to_end(key) + return entry + + def set(self, query: str, answer_text: str, params: Dict[str, Any]) -> L2CacheEntry: + now = self._now() + self._evict_expired(now) + + key = normalize_query_text(query) + entry = L2CacheEntry( + normalized_query_text=key, + answer_text=str(answer_text), + params_signature=self._params_signature(params), + created_at=now, + expires_at=now + self.ttl_seconds, + last_access_at=now, + access_count=1, + ) + + if key in self._entries: + self._entries.pop(key, None) + self._entries[key] = entry + self._evict_lru_if_needed() + return entry + + def stats(self) -> Dict[str, int]: + return { + "entries": len(self._entries), + "max_entries": self.max_entries, + "ttl_seconds": self.ttl_seconds, + } diff --git a/src/main.py b/src/main.py index 73d0783d..90749920 100644 --- a/src/main.py +++ b/src/main.py @@ -13,6 +13,7 @@ from src.config import RAGConfig from src.l1_cache import L1RetrievalCache +from src.l2_cache import L2AnswerCache from src.generator import answer, double_answer, dedupe_generated_text from src.index_builder import build_index from src.instrumentation.logging import get_logger @@ -124,6 +125,50 @@ def get_answer( chunks_info = None hyde_query = None cache_meta: Dict[str, Any] = {} + l2_cache: Optional[L2AnswerCache] = artifacts.get("l2_cache") + system_prompt = args.system_prompt_mode or cfg.system_prompt_mode + use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt + generation_params = { + "gen_model": cfg.gen_model, + "system_prompt_mode": system_prompt, + "max_gen_tokens": cfg.max_gen_tokens, + "use_double_prompt": use_double, + } + + if not is_test_mode and cfg.enable_l2_cache and l2_cache is not None: + l2_entry = l2_cache.get(question, params=generation_params) + if l2_entry is not None: + cache_meta = { + "l2_cache": { + "hit": True, + "normalized_query_text": l2_entry.normalized_query_text, + "last_access_at": l2_entry.last_access_at, + "expires_at": l2_entry.expires_at, + "access_count": l2_entry.access_count, + } + } + if console: + console.print("\n[bold cyan]=== START OF ANSWER ===[/bold cyan]\n") + console.print(Markdown(l2_entry.answer_text)) + console.print("\n[bold cyan]=== END OF ANSWER ===[/bold cyan]\n") + + logger.save_chat_log( + query=question, + config_state=cfg.get_config_state(), + ordered_scores=[], + chat_request_params={ + "system_prompt": system_prompt, + "max_tokens": cfg.max_gen_tokens, + }, + top_idxs=[], + chunks=[], + sources=[], + page_map={}, + full_response=l2_entry.answer_text, + top_k=0, + additional_log_info={**(additional_log_info or {}), **cache_meta}, + ) + return l2_entry.answer_text if golden_chunks and cfg.use_golden_chunks: # Use provided golden chunks ranked_chunks = golden_chunks @@ -239,9 +284,6 @@ def get_answer( # Step 4: Generation model_path = cfg.gen_model - system_prompt = args.system_prompt_mode or cfg.system_prompt_mode - - use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt if use_double: stream_iter = double_answer( @@ -271,6 +313,10 @@ def get_answer( # Accumulate the full text while rendering incremental Markdown chunks ans = render_streaming_ans(console, stream_iter) + if cfg.enable_l2_cache and l2_cache is not None and ans.strip(): + l2_cache.set(question, ans, params=generation_params) + cache_meta["l2_cache"] = {"hit": False} + # Logging meta = artifacts.get("meta", []) page_nums = get_page_numbers(topk_idxs, meta) @@ -337,6 +383,12 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): max_entries=cfg.l1_cache_max_entries, ttl_seconds=cfg.l1_cache_ttl_seconds, ) + l2_cache = None + if cfg.enable_l2_cache: + l2_cache = L2AnswerCache( + max_entries=cfg.l2_cache_max_entries, + ttl_seconds=cfg.l2_cache_ttl_seconds, + ) print("Loaded retrievers and initialized ranker.") artifacts = { "chunks": chunks, @@ -345,6 +397,7 @@ def run_chat_session(args: argparse.Namespace, cfg: RAGConfig): "ranker": ranker, "meta": meta, "l1_cache": l1_cache, + "l2_cache": l2_cache, } except Exception as e: print(f"ERROR: {e}. Run 'index' mode first.")