Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ rerank_top_k: 5
use_double_prompt: false
enable_history: true
max_history_turns: 3
enable_topic_extraction: false
enable_topic_extraction: false
enable_l2_cache: true
l2_cache_max_entries: 256
l2_cache_ttl_seconds: 600
176 changes: 152 additions & 24 deletions src/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

from src.config import RAGConfig
from src.generator import answer
from src.l1_cache import L1RetrievalCache
from src.l2_cache import L2AnswerCache
from src.feedback_store import (
init_feedback_db,
save_answer,
Expand All @@ -49,6 +51,8 @@
_config: Optional[RAGConfig] = None
_logger = None
_topic_extractor: Optional[TopicExtractor] = None
_l1_cache: Optional[L1RetrievalCache] = None
_l2_cache: Optional[L2AnswerCache] = None


class SourceItem(BaseModel):
Expand Down Expand Up @@ -149,6 +153,38 @@ def _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums,
def _retrieve_and_rank(query: str, top_k: Optional[int] = None):
chunks = _artifacts["chunks"]
effective_top_k = top_k if top_k is not None else _config.top_k

# L1 cache lookup (additive path; falls back to existing retrieval flow on miss)
cache_entry = None
if getattr(_config, "enable_l1_cache", False) and _l1_cache is not None:
retriever_names = [r.name for r in _retrievers]
retrieval_params = {
"top_k": effective_top_k,
"num_candidates": _config.num_candidates,
"ensemble_method": _config.ensemble_method,
"rrf_k": _config.rrf_k,
"ranker_weights": tuple(sorted(_config.ranker_weights.items())),
"retrievers": tuple(sorted(retriever_names)),
"query": query,
}
cache_entry = _l1_cache.get(
query=query,
embed_model=_config.embed_model,
embedding_context_window=getattr(_config, "embedding_model_context_window", 4096),
params=retrieval_params,
)

if cache_entry is not None:
ordered_ids = cache_entry.top_chunk_ids
ordered_scores = cache_entry.top_chunk_scores
if top_k is not None:
ordered_ids = ordered_ids[:top_k]
ordered_scores = ordered_scores[:top_k]
else:
ordered_ids = ordered_ids[:_config.top_k]
ordered_scores = ordered_scores[:_config.top_k]
return ordered_ids, ordered_scores

pool_n = max(_config.num_candidates, effective_top_k + 10)
raw_scores: Dict[str, Dict[int, float]] = {}

Expand All @@ -157,6 +193,27 @@ def _retrieve_and_rank(query: str, top_k: Optional[int] = None):

ordered_ids, ordered_scores = _ranker.rank(raw_scores=raw_scores)

# L1 cache write-back after successful retrieval + ranking.
if getattr(_config, "enable_l1_cache", False) and _l1_cache is not None:
retriever_names = [r.name for r in _retrievers]
retrieval_params = {
"top_k": effective_top_k,
"num_candidates": _config.num_candidates,
"ensemble_method": _config.ensemble_method,
"rrf_k": _config.rrf_k,
"ranker_weights": tuple(sorted(_config.ranker_weights.items())),
"retrievers": tuple(sorted(retriever_names)),
"query": query,
}
_l1_cache.set(
query=query,
embed_model=_config.embed_model,
embedding_context_window=getattr(_config, "embedding_model_context_window", 4096),
top_chunk_ids=ordered_ids[:effective_top_k],
top_chunk_scores=ordered_scores[:effective_top_k],
params=retrieval_params,
)

if top_k is not None:
ordered_ids = ordered_ids[:top_k]
ordered_scores = ordered_scores[:top_k]
Expand All @@ -166,10 +223,26 @@ def _retrieve_and_rank(query: str, top_k: Optional[int] = None):

return ordered_ids, ordered_scores


def _l2_generation_params(
prompt_type: str,
max_chunks: int,
temperature: float,
enable_chunks: bool,
) -> Dict[str, object]:
return {
"gen_model": _config.gen_model,
"system_prompt_mode": prompt_type,
"max_gen_tokens": _config.max_gen_tokens,
"max_chunks": int(max_chunks),
"temperature": float(temperature),
"enable_chunks": bool(enable_chunks),
}

@asynccontextmanager
async def lifespan(app: FastAPI):
"""Initialize artifacts on startup."""
global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor
global _artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor, _l1_cache, _l2_cache

config_path = _resolve_config_path()
if not config_path.exists():
Expand Down Expand Up @@ -208,6 +281,23 @@ async def lifespan(app: FastAPI):
rrf_k=int(_config.rrf_k),
)

# Initialize L1 cache for retrieval/ranking outputs.
if getattr(_config, "enable_l1_cache", False):
_l1_cache = L1RetrievalCache(
max_entries=getattr(_config, "l1_cache_max_entries", 256),
ttl_seconds=getattr(_config, "l1_cache_ttl_seconds", 600),
)
else:
_l1_cache = None

if getattr(_config, "enable_l2_cache", False):
_l2_cache = L2AnswerCache(
max_entries=getattr(_config, "l2_cache_max_entries", 256),
ttl_seconds=getattr(_config, "l2_cache_ttl_seconds", 600),
)
else:
_l2_cache = None

init_feedback_db()
if _config.enable_topic_extraction:
_topic_extractor = TopicExtractor(
Expand Down Expand Up @@ -377,9 +467,22 @@ async def chat_stream(request: ChatRequest):

chunks = _artifacts["chunks"]
sources = _artifacts["sources"]

l2_params = _l2_generation_params(
prompt_type=prompt_type,
max_chunks=max_chunks,
temperature=temperature,
enable_chunks=enable_chunks,
)
l2_entry = None
if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None:
l2_entry = _l2_cache.get(request.query, params=l2_params)

if disable_chunks:
if l2_entry is not None:
ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {}
elif disable_chunks:
ranked_chunks, topk_idxs = [], []
ordered_ranked_scores = {}
else:
topk_idxs, ordered_ranked_scores = _retrieve_and_rank(request.query, top_k=max_chunks)
topk_idxs = [int(i) for i in topk_idxs]
Expand Down Expand Up @@ -410,12 +513,20 @@ async def event_generator():
yield f"data: {json.dumps({'type': 'sources', 'content': [s.dict() for s in sources_used]})}\n\n"
yield f"data: {json.dumps({'type': 'chunks_by_page', 'content': chunks_by_page})}\n\n"

# Stream generation token by token
for delta in answer(request.query, ranked_chunks, _config.gen_model,
_config.max_gen_tokens, system_prompt_mode=prompt_type, temperature=temperature):
if delta:
full_response_accumulator.append(delta) # Capture for log
yield f"data: {json.dumps({'type': 'token', 'content': delta})}\n\n"
if l2_entry is not None:
full_response_accumulator = [l2_entry.answer_text]
yield f"data: {json.dumps({'type': 'token', 'content': l2_entry.answer_text})}\n\n"
else:
# Stream generation token by token
for delta in answer(request.query, ranked_chunks, _config.gen_model,
_config.max_gen_tokens, system_prompt_mode=prompt_type, temperature=temperature):
if delta:
full_response_accumulator.append(delta) # Capture for log
yield f"data: {json.dumps({'type': 'token', 'content': delta})}\n\n"

final_answer = "".join(full_response_accumulator)
if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None and final_answer.strip():
_l2_cache.set(request.query, final_answer, params=l2_params)

if _logger:
success_log = _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums, full_response_accumulator, request,
Expand Down Expand Up @@ -498,10 +609,24 @@ async def chat(request: ChatRequest):

chunks = _artifacts["chunks"]
sources = _artifacts["sources"]
l2_params = _l2_generation_params(
prompt_type=prompt_type,
max_chunks=max_chunks,
temperature=temperature,
enable_chunks=enable_chunks,
)

try:
# 2. L2 Cache fast path for exact query + generation params
l2_entry = None
if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None:
l2_entry = _l2_cache.get(request.query, params=l2_params)

# 2. Retrieval & Ranking (SAFE against mocked None return)
if disable_chunks:
if l2_entry is not None:
ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {}
answer_text = l2_entry.answer_text
elif disable_chunks:
ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {}
else:
retrieval_result = _retrieve_and_rank(
Expand All @@ -527,22 +652,25 @@ async def chat(request: ChatRequest):
raise HTTPException(status_code=500, detail="Model path not configured.")

# 3. Full Generation
try:
answer_text = "".join(
answer(
request.query,
ranked_chunks,
_config.gen_model,
_config.max_gen_tokens,
system_prompt_mode=prompt_type,
temperature=temperature,
if l2_entry is None:
try:
answer_text = "".join(
answer(
request.query,
ranked_chunks,
_config.gen_model,
_config.max_gen_tokens,
system_prompt_mode=prompt_type,
temperature=temperature,
)
)
if getattr(_config, "enable_l2_cache", False) and _l2_cache is not None and answer_text.strip():
_l2_cache.set(request.query, answer_text, params=l2_params)
except Exception as gen_error:
print(f"Generation failed: {str(gen_error)}")
answer_text = (
"I'm sorry, but I couldn't generate a response due to an internal error."
)
)
except Exception as gen_error:
print(f"Generation failed: {str(gen_error)}")
answer_text = (
"I'm sorry, but I couldn't generate a response due to an internal error."
)

# 4. Post-processing (Metadata & Pages)
page_nums = get_page_numbers(topk_idxs, _artifacts["meta"]) or {}
Expand Down
10 changes: 10 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ class RAGConfig:
)
rerank_mode: str = ""
rerank_top_k: int = 5
enable_l1_cache: bool = True
l1_cache_max_entries: int = 256
l1_cache_ttl_seconds: int = 600
enable_l2_cache: bool = True
l2_cache_max_entries: int = 256
l2_cache_ttl_seconds: int = 600

# generation
max_gen_tokens: int = 400
Expand Down Expand Up @@ -69,6 +75,10 @@ def __post_init__(self):
"""Validation logic runs automatically after initialization."""
assert self.top_k > 0, "top_k must be > 0"
assert self.num_candidates >= self.top_k, "num_candidates must be >= top_k"
assert self.l1_cache_max_entries > 0, "l1_cache_max_entries must be > 0"
assert self.l1_cache_ttl_seconds > 0, "l1_cache_ttl_seconds must be > 0"
assert self.l2_cache_max_entries > 0, "l2_cache_max_entries must be > 0"
assert self.l2_cache_ttl_seconds > 0, "l2_cache_ttl_seconds must be > 0"
assert self.ensemble_method.lower() in {"linear", "weighted", "rrf"}
assert self.embedding_model_context_window > 0, "embedding_model_context_window must be > 0"
if self.ensemble_method.lower() in {"linear", "weighted"}:
Expand Down
29 changes: 20 additions & 9 deletions src/embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,26 @@ def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = None):
self.model_path = model_path
self.n_ctx = n_ctx

self.model = Llama(
model_path=model_path,
n_ctx=n_ctx,
n_threads=n_threads,
embedding=True,
verbose=False,
use_mmap=True,
n_gpu_layers=-1,
)
try:
self.model = Llama(
model_path=model_path,
n_ctx=n_ctx,
n_threads=n_threads,
embedding=True,
verbose=False,
use_mmap=True,
n_gpu_layers=-1,
)
except Exception as e:
print(f"Error loading embedding model from {model_path} on GPU: {e}")
self.model = Llama(
model_path=model_path,
n_ctx=n_ctx,
n_threads=n_threads,
embedding=True,
verbose=False,
use_mmap=True,
)
self._embedding_dimension = None

# Warm up — also caches embedding dimension
Expand Down
Loading