diff --git a/config/config.yaml b/config/config.yaml index 286da7ca..f2d843e6 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -1,14 +1,15 @@ embed_model: "models/Qwen3-Embedding-4B-Q5_K_M.gguf" +embedding_model_context_window: 4096 top_k: 10 num_candidates: 50 ensemble_method: "rrf" -ranker_weights: {"faiss":1,"bm25":0,"index_keywords":0} +ranker_weights: {"faiss":0.75,"bm25":0.25,"index_keywords":0} rrf_k: 60 max_gen_tokens: 400 -chunk_mode : "recursive_sections" +chunk_mode: "recursive_sections" gen_model: "models/qwen2.5-3b-instruct-q8_0.gguf" -chunk_size: 2000 -chunk_overlap: 200 +chunk_size_in_chars: 2000 +chunk_overlap: 300 use_hyde: false hyde_max_tokens: 300 use_indexed_chunks: false @@ -17,4 +18,4 @@ rerank_top_k: 5 use_double_prompt: false enable_history: true max_history_turns: 3 -enable_topic_extraction: true +enable_topic_extraction: false \ No newline at end of file diff --git a/src/api_server.py b/src/api_server.py index 5e9cf243..a05685d0 100644 --- a/src/api_server.py +++ b/src/api_server.py @@ -10,11 +10,10 @@ from uuid import uuid4 from copy import deepcopy from contextlib import asynccontextmanager -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any, Tuple import traceback import os -# Add project root to Python path to allow imports when run directly _project_root = pathlib.Path(__file__).resolve().parent.parent if str(_project_root) not in sys.path: sys.path.insert(0, str(_project_root)) @@ -35,14 +34,19 @@ ) from src.instrumentation.logging import get_logger from src.ranking.ranker import EnsembleRanker -from src.retriever import filter_retrieved_chunks, BM25Retriever, FAISSRetriever, IndexKeywordRetriever, get_page_numbers, load_artifacts +from src.retriever import ( + filter_retrieved_chunks, + BM25Retriever, + FAISSRetriever, + IndexKeywordRetriever, + get_page_numbers, + load_artifacts, +) +from src.ranking.reranker import rerank from src.user_feedback_model import TopicExtractor, 
estimate_difficulty -# Constants INDEX_PREFIX = "textbook_index" - -# Global state populated during app lifespan _artifacts: Optional[Dict[str, List[str]]] = None _retrievers: Optional[List] = None _ranker: Optional[EnsembleRanker] = None @@ -54,18 +58,18 @@ class SourceItem(BaseModel): page: int text: str - + class Config: - frozen = True # Makes the model hashable so it can be used in sets + frozen = True class ChatRequest(BaseModel): query: str enable_chunks: Optional[bool] = None - prompt_type: Optional[str] = None # Maps to system_prompt_mode - max_chunks: Optional[int] = None # Maps to top_k for retrieval + prompt_type: Optional[str] = None + max_chunks: Optional[int] = None temperature: Optional[float] = None - top_k: Optional[int] = None # Alternative name for max_chunks, takes precedence if both provided + top_k: Optional[int] = None session_id: Optional[str] = None @@ -92,7 +96,6 @@ class ChatResponse(BaseModel): def _resolve_config_path() -> pathlib.Path: - """Return the absolute path to the API config.""" return pathlib.Path(__file__).resolve().parent.parent / "config" / "config.yaml" @@ -100,106 +103,189 @@ def _ensure_initialized(): if not all([_config, _artifacts, _retrievers, _ranker]): raise HTTPException( status_code=503, - detail="Artifacts not loaded. Please run indexing first." + detail="Artifacts not loaded. Please run indexing first.", ) -def _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums, full_response_accumulator, request, - enable_chunks, prompt_type, max_chunks, temperature): + +def _retrieve_and_rank( + query: str, + top_k: Optional[int] = None, +) -> Tuple[List[int], List[float], Dict[str, Dict[int, float]]]: + """ + Run all retrievers and fuse scores. + + Returns: + ordered_ids: Chunk indices sorted by fused score, truncated to top_k. + ordered_scores: Corresponding fused scores. + raw_scores: Per-retriever score dicts โ used for diagnostics. 
+ """ + chunks = _artifacts["chunks"] + effective_top_k = top_k if top_k is not None else _config.top_k + pool_n = max(_config.num_candidates, effective_top_k + 10) + + raw_scores: Dict[str, Dict[int, float]] = {} + for retriever in _retrievers: + raw_scores[retriever.name] = retriever.get_scores(query, pool_n, chunks) + + ordered_ids, ordered_scores = _ranker.rank(raw_scores=raw_scores) + + limit = effective_top_k + ordered_ids = ordered_ids[:limit] + ordered_scores = ordered_scores[:limit] + + return ordered_ids, ordered_scores, raw_scores + + +def _build_chunk_diagnostics( + topk_idxs: List[int], + ordered_ids: List[int], + ordered_scores: List[float], + raw_scores: Dict[str, Dict[int, float]], +) -> Dict[int, Dict[str, Any]]: + """ + Build per-chunk diagnostic dict (pre-reranking). + post_reranking_rank and cross_encoder_score are filled in after reranking. + """ + faiss_scores = raw_scores.get("faiss", {}) + bm25_scores = raw_scores.get("bm25", {}) + + faiss_ranked = sorted(faiss_scores, key=lambda i: faiss_scores[i], reverse=True) + bm25_ranked = sorted(bm25_scores, key=lambda i: bm25_scores[i], reverse=True) + + faiss_ranks = {idx: rank + 1 for rank, idx in enumerate(faiss_ranked)} + bm25_ranks = {idx: rank + 1 for rank, idx in enumerate(bm25_ranked)} + + # post_fusion_rank: rank within the full fused ordering (not just top-k) + post_fusion_ranks = {idx: rank + 1 for rank, idx in enumerate(ordered_ids)} + + diagnostics: Dict[int, Dict[str, Any]] = {} + for idx in topk_idxs: + diagnostics[idx] = { + "faiss_score": faiss_scores.get(idx, None), + "faiss_rank": faiss_ranks.get(idx, None), + "bm25_score": bm25_scores.get(idx, None), + "bm25_rank": bm25_ranks.get(idx, None), + "post_fusion_rank": post_fusion_ranks.get(idx, None), + "post_reranking_rank": None, # filled after reranking + "cross_encoder_score": None, # filled after reranking + } + return diagnostics + + +def _apply_reranking( + query: str, + topk_idxs: List[int], + chunk_diagnostics: Dict[int, 
Dict[str, Any]], +) -> Tuple[List[int], List[Any]]: + """ + Run cross-encoder reranking and fill diagnostics in-place. + + Returns: + reranked_idxs: Chunk indices in reranked order. + reranked_chunks: List of (text, ce_score) tuples for the generator. + """ + chunks = _artifacts["chunks"] + indexed_chunks = [(idx, chunks[idx]) for idx in topk_idxs] + + reranked = rerank( + query, + indexed_chunks, + mode=_config.rerank_mode, + top_n=_config.rerank_top_k, + ) + # reranked: List[Tuple[int, str, float]] + + reranked_idxs = [idx for idx, _, _ in reranked] + reranked_chunks = [(text, ce_score) for _, text, ce_score in reranked] + + for rerank_pos, (idx, _, ce_score) in enumerate(reranked, start=1): + if idx in chunk_diagnostics: + chunk_diagnostics[idx]["post_reranking_rank"] = rerank_pos + chunk_diagnostics[idx]["cross_encoder_score"] = ce_score + + return reranked_idxs, reranked_chunks + + +def _create_log( + chunks: List[str], + sources: List[str], + topk_idxs: List[int], + ordered_ranked_scores: List[float], + page_nums: Dict, + full_response_accumulator: List[str], + request: ChatRequest, + enable_chunks: bool, + prompt_type: str, + max_chunks: int, + temperature: float, + chunk_diagnostics: Optional[Dict[int, Dict[str, Any]]] = None, +) -> bool: try: - # Capture the actual strings used for the log file - log_chunks = [chunks[i] for i in topk_idxs[:max_chunks]] - log_sources = [sources[i] for i in topk_idxs[:max_chunks]] - - # Just Logging + # Align everything to the actual number of reranked chunks, + # NOT max_chunks โ reranking may have reduced the count to rerank_top_k. + n = len(topk_idxs) + log_chunks = [chunks[i] for i in topk_idxs] + log_sources = [sources[i] for i in topk_idxs] + + # Trim fusion scores to match reranked count. + # ordered_ranked_scores comes from fusion (pre-reranking) so may be longer. 
+ log_scores = ordered_ranked_scores[:n] + _logger.save_chat_log( query=request.query, config_state=_config.get_config_state(), - ordered_scores=ordered_ranked_scores, + ordered_scores=log_scores, chat_request_params={ - "enable_chunks": { - "provided": request.enable_chunks, - "used": enable_chunks - }, - "prompt_type": { - "provided": request.prompt_type, - "used": prompt_type - }, - "max_chunks": { - "provided": request.max_chunks, - "used": max_chunks - }, - "temperature": { - "provided": request.temperature, - "used": temperature - } + "enable_chunks": {"provided": request.enable_chunks, "used": enable_chunks}, + "prompt_type": {"provided": request.prompt_type, "used": prompt_type}, + "max_chunks": {"provided": request.max_chunks, "used": max_chunks}, + "temperature": {"provided": request.temperature, "used": temperature}, }, - top_idxs=topk_idxs[:max_chunks], + top_idxs=topk_idxs, chunks=log_chunks, sources=log_sources, page_map=page_nums, full_response="".join(full_response_accumulator), - top_k=max_chunks + top_k=n, + chunk_diagnostics=chunk_diagnostics, ) - return True except Exception as log_exc: + print(f"Logging error: {log_exc}") + traceback.print_exc() return False -def _retrieve_and_rank(query: str, top_k: Optional[int] = None): - chunks = _artifacts["chunks"] - effective_top_k = top_k if top_k is not None else _config.top_k - pool_n = max(_config.num_candidates, effective_top_k + 10) - raw_scores: Dict[str, Dict[int, float]] = {} - - for retriever in _retrievers: - raw_scores[retriever.name] = retriever.get_scores(query, pool_n, chunks) - - ordered_ids, ordered_scores = _ranker.rank(raw_scores=raw_scores) - - if top_k is not None: - ordered_ids = ordered_ids[:top_k] - ordered_scores = ordered_scores[:top_k] - else: - ordered_ids = ordered_ids[:_config.top_k] - ordered_scores = ordered_scores[:_config.top_k] - - return ordered_ids, ordered_scores - @asynccontextmanager async def lifespan(app: FastAPI): - """Initialize artifacts on startup.""" global 
_artifacts, _retrievers, _ranker, _config, _logger, _topic_extractor config_path = _resolve_config_path() if not config_path.exists(): raise FileNotFoundError(f"No config file found at {config_path}") - _config = RAGConfig.from_yaml(config_path) - _logger = get_logger() + _config = RAGConfig.from_yaml(config_path) + _logger = get_logger() try: artifacts_dir = _config.get_artifacts_directory() faiss_index, bm25_index, chunks, sources, metadata = load_artifacts( artifacts_dir=artifacts_dir, - index_prefix=INDEX_PREFIX + index_prefix=INDEX_PREFIX, ) - _artifacts = { - "chunks": chunks, - "sources": sources, - "meta": metadata, - } + _artifacts = {"chunks": chunks, "sources": sources, "meta": metadata} _retrievers = [ FAISSRetriever(faiss_index, _config.embed_model), BM25Retriever(bm25_index), ] - - # Add index keyword retriever if weight > 0 if _config.ranker_weights.get("index_keywords", 0) > 0: _retrievers.append( - IndexKeywordRetriever(_config.extracted_index_path, _config.page_to_chunk_map_path) + IndexKeywordRetriever( + _config.extracted_index_path, + _config.page_to_chunk_map_path, + ) ) _ranker = EnsembleRanker( @@ -209,59 +295,56 @@ async def lifespan(app: FastAPI): ) init_feedback_db() - if _config.enable_topic_extraction: - _topic_extractor = TopicExtractor( + _topic_extractor = ( + TopicExtractor( extracted_index_path=_config.extracted_index_path, page_to_chunk_map_path=_config.page_to_chunk_map_path, ) - else: - _topic_extractor = None - + if _config.enable_topic_extraction + else None + ) print("TokenSmith API initialized successfully") + except Exception as exc: print(f"Warning: Could not load artifacts: {exc}") print(" Run indexing first or check your configuration") yield + print("Shutting down TokenSmith API...") - print("๐ Shutting down TokenSmith API...") - -# Create FastAPI app app = FastAPI( title="TokenSmith API", description="REST API for TokenSmith RAG chat functionality", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, ) -# Add 
CORS middleware app.add_middleware( CORSMiddleware, allow_origins=[ - "http://localhost:3000", # Next.js dev server - "http://localhost:5173", # Vite dev server - "http://localhost:3001", # Alternative React dev server - "http://localhost:8080", # Alternative dev server - "http://127.0.0.1:3000", # Alternative localhost format - "http://127.0.0.1:5173", # Alternative localhost format - "http://127.0.0.1:3001", # Alternative localhost format - "http://127.0.0.1:8080", # Alternative localhost format + "http://localhost:3000", + "http://localhost:5173", + "http://localhost:3001", + "http://localhost:8080", + "http://127.0.0.1:3000", + "http://127.0.0.1:5173", + "http://127.0.0.1:3001", + "http://127.0.0.1:8080", ], allow_credentials=True, allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], ) + @app.get("/api/health") async def health_check(): - """Health check endpoint.""" return {"status": "ok", "message": "TokenSmith API is running"} @app.post("/api/feedback", response_model=FeedbackResponse) async def feedback(request: FeedbackRequest): - """Store user feedback on an answer.""" if request.vote not in (1, -1): raise HTTPException(status_code=400, detail="vote must be 1 or -1") @@ -299,9 +382,8 @@ async def feedback(request: FeedbackRequest): @app.post("/api/test-chat") async def test_chat(request: ChatRequest): - """Test chat endpoint that bypasses generation to isolate issues.""" + """Test endpoint โ retrieval only, no generation.""" print(f"Test chat request: {request.query}") - try: _ensure_initialized() except HTTPException as exc: @@ -310,38 +392,26 @@ async def test_chat(request: ChatRequest): if not request.query.strip(): return {"error": "Query cannot be empty", "status": "error"} - # Parameter handling (aligned with /api/chat) enable_chunks = ( request.enable_chunks if request.enable_chunks is not None else not _config.disable_chunks ) - disable_chunks = not enable_chunks max_chunks = ( request.top_k if request.top_k is not None 
else (request.max_chunks if request.max_chunks is not None else _config.top_k) ) - if disable_chunks: - return { - "error": "Chunk retrieval disabled; enable chunks to test retrieval.", - "status": "error", - } + if not enable_chunks: + return {"error": "Chunk retrieval disabled.", "status": "error"} try: - # โ Correct order (matches /api/chat) - topk_idxs, ordered_ranked_scores = _retrieve_and_rank( + topk_idxs, ordered_ranked_scores, raw_scores = _retrieve_and_rank( request.query, top_k=max_chunks ) - - # Ensure safe types - topk_idxs = [int(i) for i in (topk_idxs or [])] - ordered_ranked_scores = ordered_ranked_scores or {} - - ranked_chunks = [ - _artifacts["chunks"][i] for i in topk_idxs[:max_chunks] - ] + topk_idxs = [int(i) for i in topk_idxs] + ranked_chunks = [_artifacts["chunks"][i] for i in topk_idxs[:max_chunks]] return { "status": "success", @@ -352,74 +422,97 @@ async def test_chat(request: ChatRequest): "top_idxs": topk_idxs, "message": "Retrieval and ranking successful, generation skipped", } - except Exception as e: - print(f"Test chat error: {str(e)}") - import traceback traceback.print_exc() return {"error": str(e), "status": "error"} - @app.post("/api/chat/stream") async def chat_stream(request: ChatRequest): """Streaming chat endpoint using Server-Sent Events.""" - import json _ensure_initialized() if not request.query.strip(): raise HTTPException(status_code=400, detail="Query cannot be empty") - - enable_chunks = request.enable_chunks if request.enable_chunks is not None else not _config.disable_chunks + + enable_chunks = ( + request.enable_chunks + if request.enable_chunks is not None + else not _config.disable_chunks + ) disable_chunks = not enable_chunks - prompt_type = request.prompt_type if request.prompt_type is not None else _config.system_prompt_mode - max_chunks = request.top_k if request.top_k is not None else (request.max_chunks if request.max_chunks is not None else _config.top_k) - temperature = request.temperature if 
request.temperature is not None else 0.7 - - chunks = _artifacts["chunks"] + prompt_type = request.prompt_type if request.prompt_type is not None else _config.system_prompt_mode + max_chunks = request.top_k if request.top_k is not None else (request.max_chunks if request.max_chunks is not None else _config.top_k) + temperature = request.temperature if request.temperature is not None else 0.7 + + chunks = _artifacts["chunks"] sources = _artifacts["sources"] - + + # --- Retrieval, ranking, reranking --- + chunk_diagnostics: Dict[int, Dict[str, Any]] = {} + ordered_ranked_scores: List[float] = [] + if disable_chunks: ranked_chunks, topk_idxs = [], [] else: - topk_idxs, ordered_ranked_scores = _retrieve_and_rank(request.query, top_k=max_chunks) + topk_idxs, ordered_ranked_scores, raw_scores = _retrieve_and_rank( + request.query, top_k=max_chunks + ) topk_idxs = [int(i) for i in topk_idxs] - ranked_chunks = [chunks[i] for i in topk_idxs[:max_chunks]] - - if not _config.gen_model: - raise HTTPException(status_code=500, detail="Model path not configured.") - answer_id = str(uuid4()) + chunk_diagnostics = _build_chunk_diagnostics( + topk_idxs, topk_idxs, ordered_ranked_scores, raw_scores + ) + + topk_idxs, ranked_chunks = _apply_reranking( + request.query, topk_idxs, chunk_diagnostics + ) + + answer_id = str(uuid4()) session_id = request.session_id or str(uuid4()) - + async def event_generator(): full_response_accumulator = [] try: page_nums = get_page_numbers(topk_idxs, _artifacts["meta"]) - sources_used = set() + sources_used: set = set() chunks_by_page: Dict[int, List[str]] = {} + for i in topk_idxs[:max_chunks]: - source_text = sources[i] pages = page_nums.get(i, [1]) or [1] - - print(f"[DEBUG] i={i} pages={pages!r} page_nums_has_key={i in page_nums}", flush=True) - for page in pages: chunks_by_page.setdefault(page, []).append(chunks[i]) - sources_used.add(SourceItem(page=page, text=source_text)) - + sources_used.add(SourceItem(page=page, text=sources[i])) + yield 
f"data: {json.dumps({'type': 'sources', 'content': [s.dict() for s in sources_used]})}\n\n" yield f"data: {json.dumps({'type': 'chunks_by_page', 'content': chunks_by_page})}\n\n" - # Stream generation token by token - for delta in answer(request.query, ranked_chunks, _config.gen_model, - _config.max_gen_tokens, system_prompt_mode=prompt_type, temperature=temperature): + for delta in answer( + request.query, + ranked_chunks, + _config.gen_model, + _config.max_gen_tokens, + system_prompt_mode=prompt_type, + temperature=temperature, + ): if delta: - full_response_accumulator.append(delta) # Capture for log + full_response_accumulator.append(delta) yield f"data: {json.dumps({'type': 'token', 'content': delta})}\n\n" - + if _logger: - success_log = _create_log(chunks , sources , topk_idxs, ordered_ranked_scores, page_nums, full_response_accumulator, request, - enable_chunks, prompt_type, max_chunks, temperature) + success_log = _create_log( + chunks, + sources, + topk_idxs, + ordered_ranked_scores, + page_nums, + full_response_accumulator, + request, + enable_chunks, + prompt_type, + max_chunks, + temperature, + chunk_diagnostics=chunk_diagnostics, + ) if not success_log: print("Logging failed for this request.") @@ -454,79 +547,70 @@ async def event_generator(): }, ) - # Include sources in the final done message for completeness yield f"data: {json.dumps({'type': 'done', 'answer_id': answer_id, 'session_id': session_id, 'sources': [s.dict() for s in sources_used]})}\n\n" + except Exception as e: - # Using print here so you can see crashes in the terminal while debugging print(f"Backend error: {e}") traceback.print_exc() yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n" - + return StreamingResponse(event_generator(), media_type="text/event-stream") @app.post("/api/chat", response_model=ChatResponse) async def chat(request: ChatRequest): - """Main chat endpoint (Non-streaming).""" + """Main chat endpoint (non-streaming).""" _ensure_initialized() if 
not request.query.strip(): raise HTTPException(status_code=400, detail="Query cannot be empty") - # 1. Parameter setup (Syncing with stream logic) enable_chunks = ( request.enable_chunks if request.enable_chunks is not None else not _config.disable_chunks ) disable_chunks = not enable_chunks - prompt_type = ( - request.prompt_type - if request.prompt_type is not None - else _config.system_prompt_mode - ) - max_chunks = ( - request.top_k - if request.top_k is not None - else ( - request.max_chunks - if request.max_chunks is not None - else _config.top_k - ) - ) - temperature = request.temperature if request.temperature is not None else 0.7 + prompt_type = request.prompt_type if request.prompt_type is not None else _config.system_prompt_mode + max_chunks = request.top_k if request.top_k is not None else (request.max_chunks if request.max_chunks is not None else _config.top_k) + temperature = request.temperature if request.temperature is not None else 0.7 - chunks = _artifacts["chunks"] + chunks = _artifacts["chunks"] sources = _artifacts["sources"] + chunk_diagnostics: Dict[int, Dict[str, Any]] = {} + ordered_ranked_scores: List[float] = [] + try: - # 2. 
Retrieval & Ranking (SAFE against mocked None return) + # --- Retrieval, ranking, reranking --- if disable_chunks: - ranked_chunks, topk_idxs, ordered_ranked_scores = [], [], {} + ranked_chunks, topk_idxs = [], [] else: - retrieval_result = _retrieve_and_rank( - request.query, top_k=max_chunks - ) + retrieval_result = _retrieve_and_rank(request.query, top_k=max_chunks) - # ๐ Safe unpacking for unit tests where ranker is mocked if ( not retrieval_result or not isinstance(retrieval_result, (list, tuple)) - or len(retrieval_result) != 2 + or len(retrieval_result) != 3 ): - topk_idxs, ordered_ranked_scores = [], {} + topk_idxs, ordered_ranked_scores, raw_scores = [], [], {} else: - topk_idxs, ordered_ranked_scores = retrieval_result + topk_idxs, ordered_ranked_scores, raw_scores = retrieval_result - topk_idxs = [int(i) for i in (topk_idxs or [])] - ordered_ranked_scores = ordered_ranked_scores or {} + topk_idxs = [int(i) for i in topk_idxs] - ranked_chunks = [chunks[i] for i in topk_idxs[:max_chunks]] + chunk_diagnostics = _build_chunk_diagnostics( + topk_idxs, topk_idxs, ordered_ranked_scores, raw_scores + ) + + topk_idxs, ranked_chunks = _apply_reranking( + request.query, topk_idxs, chunk_diagnostics + ) if not _config.gen_model: raise HTTPException(status_code=500, detail="Model path not configured.") - # 3. Full Generation + # --- Generation --- try: answer_text = "".join( answer( @@ -539,33 +623,25 @@ async def chat(request: ChatRequest): ) ) except Exception as gen_error: - print(f"Generation failed: {str(gen_error)}") - answer_text = ( - "I'm sorry, but I couldn't generate a response due to an internal error." - ) + print(f"Generation failed: {gen_error}") + answer_text = "I'm sorry, but I couldn't generate a response due to an internal error." - # 4. 
Post-processing (Metadata & Pages) + # --- Page/source metadata --- page_nums = get_page_numbers(topk_idxs, _artifacts["meta"]) or {} - - sources_used = set() + sources_used: set = set() chunks_by_page: Dict[int, List[str]] = {} for i in topk_idxs[:max_chunks]: - source_text = sources[i] pages = page_nums.get(i, [1]) - if isinstance(pages, list): for page in pages: - sources_used.add(SourceItem(page=int(page), text=source_text)) + sources_used.add(SourceItem(page=int(page), text=sources[i])) chunks_by_page.setdefault(int(page), []).append(chunks[i]) elif isinstance(pages, int): - sources_used.add(SourceItem(page=int(pages), text=source_text)) + sources_used.add(SourceItem(page=int(pages), text=sources[i])) chunks_by_page.setdefault(int(pages), []).append(chunks[i]) - else: # Error case - print(f"Unexpected page number format for chunk index {i}: {pages}") - - # 5. Logging + # --- Logging --- if _logger: success_log = _create_log( chunks, @@ -579,12 +655,12 @@ async def chat(request: ChatRequest): prompt_type, max_chunks, temperature, + chunk_diagnostics=chunk_diagnostics, ) if not success_log: print("Logging failed for this request.") - - answer_id = str(uuid4()) + answer_id = str(uuid4()) session_id = request.session_id or str(uuid4()) retrieval_info = { "chunks_used": topk_idxs[:max_chunks], @@ -616,11 +692,11 @@ async def chat(request: ChatRequest): "answer_id": answer_id, }, ) - + return ChatResponse( answer_id=answer_id, session_id=session_id, - answer=answer_text.strip() if answer_text and answer_text.strip() else "No response generated", + answer=answer_text.strip() or "No response generated", sources=list(sources_used), chunks_used=topk_idxs, chunks_by_page=chunks_by_page, @@ -628,13 +704,10 @@ async def chat(request: ChatRequest): ) except Exception as e: - print(f"Error processing query: {str(e)}") - raise HTTPException( - status_code=500, - detail=f"Error processing query: {str(e)}", - ) + print(f"Error processing query: {e}") + raise 
HTTPException(status_code=500, detail=f"Error processing query: {e}") if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file diff --git a/src/config.py b/src/config.py index b5c86437..35811fc6 100644 --- a/src/config.py +++ b/src/config.py @@ -14,15 +14,16 @@ class RAGConfig: # chunking chunk_config: ChunkConfig = field(init=False) chunk_mode: str = "recursive_sections" - chunk_size: int = 2000 - chunk_overlap: int = 200 + chunk_size_in_chars: int = 2000 + chunk_overlap: int = 300 # retrieval + ranking top_k: int = 10 num_candidates: int = 60 embed_model: str = "models/Qwen3-Embedding-4B-Q5_K_M.gguf" + embedding_model_context_window: int = 4096 ensemble_method: str = "rrf" - rrf_k: int = 60 + rrf_k: int = 60 ranker_weights: Dict[str, float] = field( default_factory=lambda: {"faiss": 1.0, "bm25": 0.0, "index_keywords": 0.0} ) @@ -32,7 +33,7 @@ class RAGConfig: # generation max_gen_tokens: int = 400 gen_model: str = "models/qwen2.5-3b-instruct-q8_0.gguf" - + # testing system_prompt_mode: str = "baseline" disable_chunks: bool = False @@ -48,7 +49,7 @@ class RAGConfig: # conversational memory enable_history: bool = True max_history_turns: int = 3 - + # index parameters use_indexed_chunks: bool = False extracted_index_path: os.PathLike = "data/extracted_index.json" @@ -61,17 +62,18 @@ class RAGConfig: @classmethod def from_yaml(cls, path: os.PathLike) -> RAGConfig: with open(path, 'r') as f: - data = yaml.safe_load(open(path)) + data = yaml.safe_load(f) return cls(**data) - + def __post_init__(self): """Validation logic runs automatically after initialization.""" assert self.top_k > 0, "top_k must be > 0" assert self.num_candidates >= self.top_k, "num_candidates must be >= top_k" - assert self.ensemble_method.lower() in {"linear","weighted","rrf"} - if self.ensemble_method.lower() in {"linear","weighted"}: + assert self.ensemble_method.lower() in {"linear", 
"weighted", "rrf"} + assert self.embedding_model_context_window > 0, "embedding_model_context_window must be > 0" + if self.ensemble_method.lower() in {"linear", "weighted"}: s = sum(self.ranker_weights.values()) or 1.0 - self.ranker_weights = {k: v/s for k, v in self.ranker_weights.items()} + self.ranker_weights = {k: v / s for k, v in self.ranker_weights.items()} self.chunk_config = self.get_chunk_config() self.chunk_config.validate() @@ -81,8 +83,8 @@ def get_chunk_config(self) -> ChunkConfig: """Parse chunk configuration from YAML.""" if self.chunk_mode == "recursive_sections": return SectionRecursiveConfig( - recursive_chunk_size=self.chunk_size, - recursive_overlap=self.chunk_overlap + recursive_chunk_size=self.chunk_size_in_chars, + recursive_overlap=self.chunk_overlap, ) else: raise ValueError(f"Unknown chunk_mode: {self.chunk_mode}. Supported: recursive_sections") @@ -98,14 +100,12 @@ def get_artifacts_directory(self) -> os.PathLike: strategy_dir = pathlib.Path("index", strategy.artifact_folder_name()) strategy_dir.mkdir(parents=True, exist_ok=True) return strategy_dir - - def get_config_state(self) -> None: - """Returns dict of all config parameters except chunk_config """ + + def get_config_state(self) -> dict: + """Returns dict of all config parameters except chunk_config.""" state = self.__dict__.copy() - state.pop("chunk_config", None) # remove chunk_config to avoid serialization issues - # also pop any non-serializable fields if needed + state.pop("chunk_config", None) for key in list(state.keys()): if not isinstance(state[key], (int, float, str, bool, list, dict, type(None))): state.pop(key) - return state - + return state \ No newline at end of file diff --git a/src/embedder.py b/src/embedder.py index d543b322..1b95f532 100644 --- a/src/embedder.py +++ b/src/embedder.py @@ -12,10 +12,9 @@ _worker_model: Optional[Llama] = None _worker_embedding_dim: int = 0 + def _init_worker(model_path: str, n_ctx: int, n_threads: int): - """ - Initializes the 
model inside a worker process. - """ + """Initializes the model inside a worker process.""" global _worker_model, _worker_embedding_dim _worker_model = Llama( @@ -24,57 +23,55 @@ def _init_worker(model_path: str, n_ctx: int, n_threads: int): n_threads=n_threads, embedding=True, verbose=False, - use_mmap=True # Allows OS to share model weights across processes + use_mmap=True, ) - - # Cache dimension + test_emb = _worker_model.create_embedding("test")['data'][0]['embedding'] _worker_embedding_dim = len(test_emb) + def _encode_batch_worker(texts: List[str]) -> List[List[float]]: - """ - Encodes a batch of text using the worker's local model instance. - """ + """Encodes a batch of text using the worker's local model instance.""" global _worker_model, _worker_embedding_dim if _worker_model is None: return [] - + embeddings = [] for text in texts: try: - # Create embedding emb = _worker_model.create_embedding(text)['data'][0]['embedding'] embeddings.append(emb) except Exception: - # Return zero vector on failure embeddings.append([0.0] * _worker_embedding_dim) - + return embeddings + class SentenceTransformer: def __init__(self, model_path: str, n_ctx: int = 4096, n_threads: int = None): """ Initialize with a local GGUF model file path. - + Args: model_path: Path to your local .gguf file - n_ctx: Context window size (increased to match Qwen3 training context) - n_threads: Number of threads to use (None = auto-detect) + n_ctx: Context window size. Defaults to 4096. 
+ n_threads: Number of threads (None = auto-detect) """ self.model_path = model_path self.n_ctx = n_ctx - + self.model = Llama( model_path=model_path, n_ctx=n_ctx, n_threads=n_threads, embedding=True, - verbose=True, + verbose=False, use_mmap=True, - n_gpu_layers=-1 # use GPU if available + n_gpu_layers=-1, ) self._embedding_dimension = None - + + # Warm up โ also caches embedding dimension _ = self.embedding_dimension @property @@ -85,116 +82,102 @@ def embedding_dimension(self) -> int: self._embedding_dimension = len(test_embedding) return self._embedding_dimension - def encode(self, - texts: Union[str, List[str]], - batch_size: int = 16, # Adjusted for 4B model - normalize: bool = False, - show_progress_bar: bool = False, - **kwargs) -> np.ndarray: - + def encode( + self, + texts: Union[str, List[str]], + batch_size: int = 1, + normalize: bool = False, + show_progress_bar: bool = False, + **kwargs, + ) -> np.ndarray: """ - Encode texts to embeddings with batch processing. - + Encode texts to embeddings sequentially. + Args: - texts: Single text or list of texts to encode - batch_size: Number of texts to process at once - normalize: Whether to normalize embeddings - show_progress_bar: Whether to show progress bar - Returns: - numpy.ndarray: Float32 embeddings array + texts: Single text or list of texts to encode. + batch_size: Unused โ kept for API compatibility only. + All encoding is sequential (one chunk per + forward pass) due to llama-cpp-python + n_seq_max limitations in 0.3.x. + normalize: Whether to L2-normalize embeddings. + show_progress_bar: Whether to show a tqdm progress bar. + Returns: + numpy.ndarray: Float32 embeddings of shape (len(texts), dim). 
""" if isinstance(texts, str): texts = [texts] - + if not texts: - return np.array([], dtype=np.float32).reshape(0, -1) - - # Process in batches + return np.array([], dtype=np.float32).reshape(0, self.embedding_dimension) + embeddings = [] - num_batches = (len(texts) + batch_size - 1) // batch_size + failed_indices = [] - for i in tqdm(range(num_batches), desc="Encoding", disable=not show_progress_bar): - start_idx = i * batch_size - end_idx = min((i + 1) * batch_size, len(texts)) - batch_texts = texts[start_idx:end_idx] - + for i, text in enumerate(tqdm(texts, desc="Encoding", disable=not show_progress_bar)): try: - # IMPORTANT CHANGE: Pass the entire LIST to the model at once. - # This triggers the native C++/Metal batch processing logic. - response = self.model.create_embedding(batch_texts) - - # Extract the list of embedding vectors from the response - batch_embeddings = [item['embedding'] for item in response['data']] - embeddings.extend(batch_embeddings) - + emb = self.model.create_embedding(text)['data'][0]['embedding'] + embeddings.append(emb) except Exception as e: - print(f"Error encoding batch: {e}") - # Fallback: encode one by one if batch fails, or append zeros - for _ in batch_texts: - embeddings.append([0.0] * self.embedding_dimension) - + print(f" [ERROR] Failed to embed chunk {i}: {e}") + print(f" Preview: '{text[:80]}...'") + failed_indices.append(i) + embeddings.append([0.0] * self.embedding_dimension) + + if failed_indices: + print(f"\n[WARNING] {len(failed_indices)} chunk(s) failed embedding: indices {failed_indices}") + print("These chunks will have zero vectors in the FAISS index and will not be retrievable.\n") + vecs = np.array(embeddings, dtype=np.float32) - - if normalize: # do L2 normalization + + if normalize: norms = np.linalg.norm(vecs, axis=1, keepdims=True) vecs = vecs / np.where(norms == 0, 1e-12, norms) - + return vecs def get_sentence_embedding_dimension(self) -> int: - """Get the dimension of embeddings (compatibility 
method).""" + """Compatibility method.""" return self.embedding_dimension def start_multi_process_pool(self, num_workers: int = None) -> multiprocessing.pool.Pool: - """ - Starts a pool of worker processes. - """ - if num_workers: - workers = num_workers - else: - # Default to CPU count - 2 (leave room for OS/Main process) - workers = max(1, multiprocessing.cpu_count() - 2) - + """Starts a pool of worker processes.""" + workers = num_workers if num_workers else max(1, multiprocessing.cpu_count() - 2) print(f"Creating {workers} worker processes...") - - # Use 1 thread per worker to avoid CPU thrashing - worker_threads = 1 - + pool = multiprocessing.Pool( processes=workers, initializer=_init_worker, - initargs=(self.model_path, self.n_ctx, worker_threads) + initargs=(self.model_path, self.n_ctx, 1), ) return pool - def encode_multi_process(self, texts: List[str], pool: multiprocessing.pool.Pool, batch_size: int = 32) -> np.ndarray: - """ - Distributes work across the pool. - """ - # Sort by length to minimize padding/processing waste + def encode_multi_process( + self, + texts: List[str], + pool: multiprocessing.pool.Pool, + batch_size: int = 32, + ) -> np.ndarray: + """Distributes encoding work across the worker pool.""" indices = np.argsort([len(t) for t in texts])[::-1] sorted_texts = [texts[i] for i in indices] - # Create batches - chunks = [sorted_texts[i : i + batch_size] for i in range(0, len(sorted_texts), batch_size)] + chunks = [sorted_texts[i:i + batch_size] for i in range(0, len(sorted_texts), batch_size)] - # Process with progress bar results = [] print(f"Dispatching {len(chunks)} batches to pool...") for batch_result in tqdm( - pool.imap(_encode_batch_worker, chunks), - total=len(chunks), - desc="Parallel Encoding" + pool.imap(_encode_batch_worker, chunks), + total=len(chunks), + desc="Parallel Encoding", ): results.append(batch_result) flat_embeddings = [emb for batch in results for emb in batch] - # Restore original order inverse_indices = 
np.empty_like(indices) inverse_indices[indices] = np.arange(len(indices)) ordered_embeddings = [flat_embeddings[i] for i in inverse_indices] - + return np.array(ordered_embeddings, dtype=np.float32) @staticmethod @@ -205,14 +188,13 @@ def stop_multi_process_pool(pool: multiprocessing.pool.Pool): class EmbeddingCache: """Persistent SQLite cache for embeddings.""" - + def __init__(self, cache_dir: str = "index/cache"): self.db_path = Path(cache_dir) / "embeddings.db" self.db_path.parent.mkdir(parents=True, exist_ok=True) self._init_db() - + def _init_db(self): - """Initialize database schema.""" with sqlite3.connect(self.db_path) as conn: conn.execute(""" CREATE TABLE IF NOT EXISTS embeddings ( @@ -225,31 +207,27 @@ def _init_db(self): ) """) conn.execute("CREATE INDEX IF NOT EXISTS idx_model_name ON embeddings(model_name)") - + def get(self, model_path: str, query: str) -> Optional[np.ndarray]: - """Retrieve cached embedding if it exists.""" model_hash = hashlib.md5(model_path.encode()).hexdigest()[:16] - with sqlite3.connect(self.db_path) as conn: row = conn.execute( "SELECT embedding FROM embeddings WHERE model_hash=? AND query_text=?", - (model_hash, query) + (model_hash, query), ).fetchone() - if row: return np.frombuffer(row[0], dtype=np.float32) return None - + def set(self, model_path: str, query: str, embedding: np.ndarray): - """Store embedding in cache.""" model_name = Path(model_path).stem model_hash = hashlib.md5(model_path.encode()).hexdigest()[:16] blob = embedding.astype(np.float32).tobytes() - with sqlite3.connect(self.db_path) as conn: conn.execute( - "INSERT OR REPLACE INTO embeddings (model_name, model_hash, query_text, embedding) VALUES (?,?,?,?)", - (model_name, model_hash, query, blob) + "INSERT OR REPLACE INTO embeddings " + "(model_name, model_hash, query_text, embedding) VALUES (?,?,?,?)", + (model_name, model_hash, query, blob), ) @@ -258,22 +236,20 @@ class CachedEmbedder: Wrapper around SentenceTransformer that caches query embeddings. 
Drop-in replacement for SentenceTransformer. """ - + def __init__(self, model_path: str, **kwargs): self.embedder = SentenceTransformer(model_path, **kwargs) self.cache = EmbeddingCache() self.model_path = model_path - + def encode(self, texts, **kwargs): - """Encode texts with caching support.""" if isinstance(texts, str): texts = [texts] - + results = [] to_compute = [] to_compute_indices = [] - - # Check cache for each text + for i, text in enumerate(texts): cached = self.cache.get(self.model_path, text) if cached is not None: @@ -281,18 +257,15 @@ def encode(self, texts, **kwargs): else: to_compute.append(text) to_compute_indices.append(i) - - # Compute missing embeddings + if to_compute: computed = self.embedder.encode(to_compute, **kwargs) for idx, text, emb in zip(to_compute_indices, to_compute, computed): self.cache.set(self.model_path, text, emb) results.append((idx, emb)) - - # Restore original order + results.sort(key=lambda x: x[0]) return np.array([emb for _, emb in results]) - + def __getattr__(self, name): - """Delegate other methods to wrapped embedder.""" - return getattr(self.embedder, name) + return getattr(self.embedder, name) \ No newline at end of file diff --git a/src/index_builder.py b/src/index_builder.py index c445fa7f..8d9c0741 100644 --- a/src/index_builder.py +++ b/src/index_builder.py @@ -2,9 +2,6 @@ """ index_builder.py PDF -> markdown text -> chunks -> embeddings -> BM25 + FAISS + metadata - -Entry point (called by main.py): - build_index(markdown_file, cfg, keep_tables=True, do_visualize=False) """ import os @@ -14,11 +11,12 @@ import json from typing import List, Dict +import numpy as np import faiss from rank_bm25 import BM25Okapi from src.embedder import SentenceTransformer -from src.preprocessing.chunking import DocumentChunker, ChunkConfig +from src.preprocessing.chunking import DocumentChunker, ChunkConfig, print_chunk_stats from src.preprocessing.extraction import extract_sections_from_markdown # ----- runtime parallelism 
knobs (avoid oversubscription) ----- @@ -29,10 +27,8 @@ os.environ["VECLIB_MAXIMUM_THREADS"] = "1" os.environ["NUMEXPR_NUM_THREADS"] = "1" -# Default keywords to exclude sections DEFAULT_EXCLUSION_KEYWORDS = ['questions', 'exercises', 'summary', 'references'] -# ------------------------ Main index builder ----------------------------- def build_index( markdown_file: str, @@ -40,10 +36,11 @@ def build_index( chunker: DocumentChunker, chunk_config: ChunkConfig, embedding_model_path: str, + embedding_model_context_window: int, artifacts_dir: os.PathLike, index_prefix: str, use_multiprocessing: bool = False, - use_headings: bool = False + use_headings: bool = False, ) -> None: """ Extract sections, chunk, embed, and build both FAISS and BM25 indexes. @@ -54,12 +51,12 @@ def build_index( - {prefix}_chunks.pkl - {prefix}_sources.pkl - {prefix}_meta.pkl + - {prefix}_page_to_chunk_map.json """ all_chunks: List[str] = [] sources: List[str] = [] metadata: List[Dict] = [] - # Extract sections from markdown. Exclude some with certain keywords. 
sections = extract_sections_from_markdown( markdown_file, exclusion_keywords=DEFAULT_EXCLUSION_KEYWORDS @@ -70,73 +67,47 @@ def build_index( total_chunks = 0 heading_stack = [] - # Step 1: Chunk using DocumentChunker + # Step 1: Chunk for i, c in enumerate(sections): - # Determine current section level current_level = c.get('level', 1) - - # Determine current chapter number chapter_num = c.get('chapter', 0) - # Pop sections that are deeper or siblings while heading_stack and heading_stack[-1][0] >= current_level: heading_stack.pop() - - # Push pair of (level, heading) + if c['heading'] != "Introduction": heading_stack.append((current_level, c['heading'])) - # Construct section path path_list = [h[1] for h in heading_stack] full_section_path = " ".join(path_list) full_section_path = f"Chapter {chapter_num} " + full_section_path - # Use DocumentChunker to recursively split this section sub_chunks = chunker.chunk(c['content']) - - # Regex to find page markers like "--- Page 3 ---" page_pattern = re.compile(r'--- Page (\d+) ---') - # Iterate through each chunk produced from this section for sub_chunk_id, sub_chunk in enumerate(sub_chunks): - # Track all pages this specific chunk touches chunk_pages = set() - - # Split the sub_chunk by page markers to see if it - # spans multiple pages. fragments = page_pattern.split(sub_chunk) - # If there is content before the first page marker, - # it belongs to the current_page. if fragments[0].strip(): - page_to_chunk_ids.setdefault(current_page, set()).add(total_chunks+sub_chunk_id) + page_to_chunk_ids.setdefault(current_page, set()).add(total_chunks + sub_chunk_id) chunk_pages.add(current_page) - # Process the new pages found within this sub_chunk. 
- # Step by 2 where each pair represents (page number, text after it) - for i in range(1, len(fragments), 2): + for idx in range(1, len(fragments), 2): try: - # Get the new page number from the marker - new_page = int(fragments[i]) + 1 - - # If there is text after this marker, it belongs to the new_page. - if fragments[i+1].strip(): + new_page = int(fragments[idx]) + 1 + if fragments[idx + 1].strip(): page_to_chunk_ids.setdefault(new_page, set()).add(total_chunks + sub_chunk_id) chunk_pages.add(new_page) - current_page = new_page - except (IndexError, ValueError): continue - # Clean sub_chunk by removing page markers clean_chunk = re.sub(page_pattern, '', sub_chunk).strip() - - # Skip introduction chunks for embedding + if c["heading"] == "Introduction": continue - - # Prepare metadata + meta = { "filename": markdown_file, "mode": chunk_config.to_string(), @@ -146,59 +117,53 @@ def build_index( "section_path": full_section_path, "text_preview": clean_chunk[:100], "page_numbers": sorted(list(chunk_pages)), - "chunk_id": total_chunks + sub_chunk_id + "chunk_id": total_chunks + sub_chunk_id, } - # Prepare chunk with prefix - if use_headings: - chunk_prefix = ( - f"Description: {full_section_path} " - f"Content: " - ) - else: - chunk_prefix = "" + chunk_prefix = ( + f"Description: {full_section_path} Content: " + if use_headings else "" + ) - all_chunks.append(chunk_prefix+clean_chunk) + all_chunks.append(chunk_prefix + clean_chunk) sources.append(markdown_file) metadata.append(meta) total_chunks += len(sub_chunks) - # Convert the sets to sorted lists for a clean, predictable output - final_map = {} - for page, id_set in page_to_chunk_ids.items(): - final_map[page] = sorted(list(id_set)) - + # Save page-to-chunk map + final_map = {page: sorted(list(ids)) for page, ids in page_to_chunk_ids.items()} output_file = artifacts_dir / f"{index_prefix}_page_to_chunk_map.json" with open(output_file, "w") as f: json.dump(final_map, f, indent=2) print(f"Saved page to chunk ID map: 
{output_file}") - # Step 2: Create embeddings for FAISS index - print(f"Embedding {len(all_chunks):,} chunks with {pathlib.Path(embedding_model_path).stem} ...") - embedder = SentenceTransformer(embedding_model_path) + # Print chunk stats before embedding + print_chunk_stats(all_chunks, chunk_size_in_chars=chunk_config.recursive_chunk_size) + + # Step 2: Load embedder + print(f"Loading embedding model (n_ctx={embedding_model_context_window})...") + embedder = SentenceTransformer( + embedding_model_path, + n_ctx=embedding_model_context_window, + ) + print(f"Embedding {len(all_chunks):,} chunks sequentially...") if use_multiprocessing: print("Starting multi-process pool for embeddings...") - # Start the pool. Adjust number of workers as needed. pool = embedder.start_multi_process_pool(workers=4) try: - # Compute embeddings in parallel embeddings = embedder.encode_multi_process( - all_chunks, - pool, - batch_size=32 + all_chunks, + pool, + batch_size=4, ) finally: - # Stop the pool to prevent hanging processes embedder.stop_multi_process_pool(pool) else: - # Standard single-process embedding embeddings = embedder.encode( - all_chunks, - batch_size=8, + all_chunks, show_progress_bar=True, - convert_to_numpy=True ) # Step 3: Build FAISS index @@ -207,7 +172,7 @@ def build_index( index = faiss.IndexFlatL2(dim) index.add(embeddings) faiss.write_index(index, str(artifacts_dir / f"{index_prefix}.faiss")) - print(f"FAISS Index built successfully: {index_prefix}.faiss") + print(f"FAISS index built: {index_prefix}.faiss") # Step 4: Build BM25 index print(f"Building BM25 index for {len(all_chunks):,} chunks...") @@ -215,9 +180,9 @@ def build_index( bm25_index = BM25Okapi(tokenized_chunks) with open(artifacts_dir / f"{index_prefix}_bm25.pkl", "wb") as f: pickle.dump(bm25_index, f) - print(f"BM25 Index built successfully: {index_prefix}_bm25.pkl") + print(f"BM25 index built: {index_prefix}_bm25.pkl") - # Step 5: Dump index artifacts + # Step 5: Persist remaining artifacts with 
open(artifacts_dir / f"{index_prefix}_chunks.pkl", "wb") as f: pickle.dump(all_chunks, f) with open(artifacts_dir / f"{index_prefix}_sources.pkl", "wb") as f: @@ -226,20 +191,9 @@ def build_index( pickle.dump(metadata, f) print(f"Saved all index artifacts with prefix: {index_prefix}") -# ------------------------ Helper functions ------------------------------ def preprocess_for_bm25(text: str) -> list[str]: - """ - Simplifies text to keep only letters, numbers, underscores, hyphens, - apostrophes, plus, and hash — suitable for BM25 tokenization. - """ - # Convert to lowercase + """Lowercase and tokenize text for BM25 indexing.""" text = text.lower() - - # Keep only allowed characters text = re.sub(r"[^a-z0-9_'#+-]", " ", text) - - # Split by whitespace - tokens = text.split() - - return tokens + return text.split() \ No newline at end of file diff --git a/src/instrumentation/logging.py b/src/instrumentation/logging.py index 7f1f4c72..072dd874 100644 --- a/src/instrumentation/logging.py +++ b/src/instrumentation/logging.py @@ -4,6 +4,7 @@ from typing import Dict, List, Any, Optional, Union import numpy as np + class NpEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.integer): @@ -14,93 +15,138 @@ def default(self, obj): return obj.tolist() return super(NpEncoder, self).default(obj) + class RunLogger: def __init__(self): self.logs_dir = Path("logs") self.logs_dir.mkdir(exist_ok=True) - def save_chat_log(self, - query: str, - chat_request_params : Optional[Dict[str, Any]], - ordered_scores: List[Union[float, int]], - config_state: Dict[str, Any], - top_idxs: List[int], - chunks: List[str], - sources: List[str], - page_map: Dict[int, int], - full_response: str, - top_k: int, - additional_log_info: Optional[Dict[str, Any]] = None): - """Creates a unique JSON file for this specific chat request.""" - - # timestamp for filename: 20240520_143005 (Sorts newest to bottom, - # but if you want newest at top, most OS sort by name DESC) + def 
save_chat_log( + self, + query: str, + chat_request_params: Optional[Dict[str, Any]], + config_state: Dict[str, Any], + top_idxs: List[int], + chunks: List[str], + sources: List[str], + page_map: Dict[int, List[int]], + full_response: str, + top_k: int, + ordered_scores: List[float], + chunk_diagnostics: Optional[Dict[int, Dict[str, Any]]] = None, + additional_log_info: Optional[Dict[str, Any]] = None, + ): + """ + Creates a unique JSON log file for this query. + + Args: + query: The user query. + chat_request_params: Params used in the chat request. + config_state: Snapshot of RAGConfig at query time. + top_idxs: Ordered list of retrieved chunk indices (post-reranking). + chunks: Full chunk list (all chunks, not just top-k). + sources: Source filepath per chunk. + page_map: Mapping of chunk_idx -> list of page numbers. + full_response: The generated answer. + top_k: Number of chunks used. + ordered_scores: Post-fusion RRF scores, aligned with top_idxs. + chunk_diagnostics: Optional dict keyed by chunk_idx with per-retriever + diagnostic fields: + { + idx: { + "faiss_score": float, + "faiss_rank": int, + "bm25_score": float, + "bm25_rank": int, + "post_fusion_rank": int, + "post_reranking_rank": int, + "cross_encoder_score": float, + } + } + additional_log_info: Any extra fields to merge into the top-level log. 
+ """ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") log_id = f"chat_{timestamp_str}" - - ### make list of dicts for retrieval - - # make page_numbers list - page_numbers_list = [page_map.get(i, 1) for i in top_idxs] - # make sure chunks, top_idxs, sources, and ordered_scores are the same length - if not (len(chunks) == len(top_idxs) == len(sources) == len(ordered_scores) == len(page_numbers_list)): - print("Warning: Lengths of chunks, top_idxs, sources, ordered_scores, and page_numbers do not match.") - print("Defaulting to long form logging ") + + page_numbers_list = [page_map.get(i, [1]) for i in top_idxs] + + lengths_match = ( + len(chunks) == len(top_idxs) == len(sources) + == len(ordered_scores) == len(page_numbers_list) + ) + + if not lengths_match: + print("Warning: Lengths of chunks, top_idxs, sources, ordered_scores, " + "and page_numbers do not match. Defaulting to long-form logging.") log_data = { "timestamp": datetime.now().isoformat(), "query": query, "chat_request_params": chat_request_params, - "config_state" : config_state, + "config_state": config_state, "top_k": top_k, "ordered_scores": ordered_scores[:len(top_idxs)], "top_idxs": top_idxs, "chunks": chunks[:len(top_idxs)], "sources": sources[:len(top_idxs)], - "page_numbers": [page_map.get(i, 1) for i in top_idxs], - "full_response": full_response + "page_numbers": [page_map.get(i, [1]) for i in top_idxs], + "full_response": full_response, } else: retrieved_chunks = [] - for i, (chunk, idx, source, score, page_number) in enumerate(zip(chunks, top_idxs, sources, ordered_scores, page_numbers_list)): - retrieved_chunks.append({ - "rank": i + 1, + for rank, (chunk_text, idx, source, score, page_numbers) in enumerate( + zip(chunks, top_idxs, sources, ordered_scores, page_numbers_list), start=1 + ): + entry = { + "rank": rank, "idx": idx, - "chunk": chunk, + "chunk": chunk_text, "source": source, - "score": score, - "page_number": page_number - }) + "page_number": page_numbers, + # 
Post-fusion RRF score — always present + "post_fusion_score": score, + } + + # Merge per-chunk diagnostics if provided + if chunk_diagnostics and idx in chunk_diagnostics: + diag = chunk_diagnostics[idx] + entry["faiss_score"] = diag.get("faiss_score", None) + entry["faiss_rank"] = diag.get("faiss_rank", None) + entry["bm25_score"] = diag.get("bm25_score", None) + entry["bm25_rank"] = diag.get("bm25_rank", None) + entry["post_fusion_rank"] = diag.get("post_fusion_rank", None) + entry["post_reranking_rank"] = diag.get("post_reranking_rank", None) + entry["cross_encoder_score"] = diag.get("cross_encoder_score", None) + + retrieved_chunks.append(entry) log_data = { "timestamp": datetime.now().isoformat(), "query": query, "chat_request_params": chat_request_params, - "config_state" : config_state, + "config_state": config_state, "top_k": top_k, "retrieved_chunks": retrieved_chunks, - "full_response": full_response + "full_response": full_response, } + if additional_log_info: - for key in additional_log_info: + for key, value in additional_log_info.items(): if key in log_data: - print(f"Warning: Key '{key}' in additional_log_info conflicts with existing log data keys. Skipping this key.") + print(f"Warning: Key '{key}' in additional_log_info conflicts " + f"with existing log key. 
Skipping.") else: - log_data[key] = additional_log_info[key] - + log_data[key] = value + log_file = self.logs_dir / f"{log_id}.json" - - # Write as a single pretty-printed JSON file with open(log_file, "w", encoding="utf-8") as f: json.dump(log_data, f, ensure_ascii=False, indent=4, cls=NpEncoder) -# Global Instance logic + +# Global instance _INSTANCE = None + def get_logger(): global _INSTANCE if _INSTANCE is None: _INSTANCE = RunLogger() - return _INSTANCE - - - - + return _INSTANCE \ No newline at end of file diff --git a/src/main.py b/src/main.py index d14b8a23..d2759372 100644 --- a/src/main.py +++ b/src/main.py @@ -70,12 +70,12 @@ def run_index_mode(args: argparse.Namespace, cfg: RAGConfig): chunker=chunker, chunk_config=cfg.chunk_config, embedding_model_path=cfg.embed_model, + embedding_model_context_window=cfg.embedding_model_context_window, artifacts_dir=artifacts_dir, index_prefix=args.index_prefix, use_multiprocessing=args.multiproc_indexing, use_headings=args.embed_with_headings, ) - def use_indexed_chunks(question: str, chunks: list) -> list: # Logic for keyword matching from textbook index try: @@ -107,149 +107,181 @@ def get_answer( is_test_mode: bool = False, additional_log_info: Optional[Dict[str, Any]] = None ) -> Union[str, Tuple[str, List[Dict[str, Any]], Optional[str]]]: - """ - Run a single query through the pipeline. 
- """ + """Run a single query through the pipeline.""" + chunks = artifacts["chunks"] sources = artifacts["sources"] retrievers = artifacts["retrievers"] ranker = artifacts["ranker"] - # Ensure these locals exist for all control flows to avoid UnboundLocalError + ranked_chunks: List[str] = [] topk_idxs: List[int] = [] - scores = [] - - # Step 1: Get chunks (golden, retrieved, or none) - chunks_info = None + scores: List[float] = [] + chunk_diagnostics: Dict[int, Dict[str, Any]] = {} hyde_query = None + chunks_info = None + + # Step 1: Get chunks if golden_chunks and cfg.use_golden_chunks: - # Use provided golden chunks ranked_chunks = golden_chunks + elif cfg.disable_chunks: - # No chunks - baseline mode ranked_chunks = [] + elif cfg.use_indexed_chunks: ranked_chunks, topk_idxs = use_indexed_chunks(question, chunks) + else: retrieval_query = question - # print(f"Retrieval query: {retrieval_query}") if cfg.use_hyde: - retrieval_query = generate_hypothetical_document(question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens) - + retrieval_query = generate_hypothetical_document( + question, cfg.gen_model, max_tokens=cfg.hyde_max_tokens + ) + hyde_query = retrieval_query + pool_n = max(cfg.num_candidates, cfg.top_k + 10) + + # Step 1a: Per-retriever raw scores raw_scores: Dict[str, Dict[int, float]] = {} for retriever in retrievers: - # print(f"Getting scores from retriever: {retriever.name}...") - raw_scores[retriever.name] = retriever.get_scores(retrieval_query, pool_n, chunks) - # TODO: Fix retrieval logging. 
- - # print("Raw scores from retrievers:") - # for retriever_name, score_dict in raw_scores.items(): - # print(f" {retriever_name}: {list(score_dict.values())}") - # Step 2: Ranking + raw_scores[retriever.name] = retriever.get_scores( + retrieval_query, pool_n, chunks + ) + + # Step 2: Fusion ranking ordered, scores = ranker.rank(raw_scores=raw_scores) - # print(f"Ordered candidate indices after ranking: {ordered[:cfg.top_k]}") - # print(f"Corresponding scores: {scores[:cfg.top_k]}") + + # Post-fusion rank lookup: chunk_idx -> 1-based rank in fused ordering + post_fusion_ranks = {idx: rank + 1 for rank, idx in enumerate(ordered)} + + # Step 3: Filter to top-k topk_idxs = filter_retrieved_chunks(cfg, chunks, ordered) - ranked_chunks = [chunks[i] for i in topk_idxs] - # print(f"Top-{cfg.top_k} chunk indices after filtering: {topk_idxs}") - # print("Len Ranked chunks:", len(ranked_chunks)) - # print("Example ranked chunk content:", ranked_chunks[0] if ranked_chunks else "No chunks retrieved") - - - # Capture chunk info if in test mode + + # Step 4: Build per-retriever rank lookups for diagnostics + faiss_scores = raw_scores.get("faiss", {}) + bm25_scores = raw_scores.get("bm25", {}) + + faiss_ranked = sorted(faiss_scores, key=lambda i: faiss_scores[i], reverse=True) + bm25_ranked = sorted(bm25_scores, key=lambda i: bm25_scores[i], reverse=True) + + faiss_ranks = {idx: rank + 1 for rank, idx in enumerate(faiss_ranked)} + bm25_ranks = {idx: rank + 1 for rank, idx in enumerate(bm25_ranked)} + + # Build diagnostics dict for each top-k chunk (pre-reranking) + for idx in topk_idxs: + chunk_diagnostics[idx] = { + "faiss_score": faiss_scores.get(idx, None), + "faiss_rank": faiss_ranks.get(idx, None), + "bm25_score": bm25_scores.get(idx, None), + "bm25_rank": bm25_ranks.get(idx, None), + "post_fusion_rank": post_fusion_ranks.get(idx, None), + # Filled in after reranking below + "post_reranking_rank": None, + "cross_encoder_score": None, + } + + # Step 5: Reranking — pass (idx, 
text) pairs so indices survive + indexed_chunks = [(idx, chunks[idx]) for idx in topk_idxs] + reranked = rerank( + question, + indexed_chunks, + mode=cfg.rerank_mode, + top_n=cfg.rerank_top_k, + ) + # reranked is List[Tuple[int, str, float]]: (idx, text, cross_encoder_score) + + # Rebuild topk_idxs and ranked_chunks in reranked order + topk_idxs = [idx for idx, _, _ in reranked] + ranked_chunks = [(text, ce_score) for _, text, ce_score in reranked] + + # Fill in post-reranking rank and cross-encoder score + for rerank_rank, (idx, _, ce_score) in enumerate(reranked, start=1): + if idx in chunk_diagnostics: + chunk_diagnostics[idx]["post_reranking_rank"] = rerank_rank + chunk_diagnostics[idx]["cross_encoder_score"] = ce_score + + # For test mode if is_test_mode: - # Compute individual ranker ranks - faiss_scores = raw_scores.get("faiss", {}) - bm25_scores = raw_scores.get("bm25", {}) - index_scores = raw_scores.get("index_keywords", {}) - - faiss_ranked = sorted(faiss_scores.keys(), key=lambda i: faiss_scores[i], reverse=True) - bm25_ranked = sorted(bm25_scores.keys(), key=lambda i: bm25_scores[i], reverse=True) - index_ranked = sorted(index_scores.keys(), key=lambda i: index_scores[i], reverse=True) - - faiss_ranks = {idx: rank + 1 for rank, idx in enumerate(faiss_ranked)} - bm25_ranks = {idx: rank + 1 for rank, idx in enumerate(bm25_ranked)} - index_ranks = {idx: rank + 1 for rank, idx in enumerate(index_ranked)} - chunks_info = [] - for rank, idx in enumerate(topk_idxs, 1): + for rerank_rank, (idx, text, ce_score) in enumerate(reranked, start=1): + diag = chunk_diagnostics.get(idx, {}) chunks_info.append({ - "rank": rank, - "chunk_id": idx, - "content": chunks[idx], - "faiss_score": faiss_scores.get(idx, 0), - "faiss_rank": faiss_ranks.get(idx, 0), - "bm25_score": bm25_scores.get(idx, 0), - "bm25_rank": bm25_ranks.get(idx, 0), - "index_score": index_scores.get(idx, 0), - "index_rank": index_ranks.get(idx, 0), + "rank": rerank_rank, + "chunk_id": idx, + "content": 
text, + "faiss_score": diag.get("faiss_score"), + "faiss_rank": diag.get("faiss_rank"), + "bm25_score": diag.get("bm25_score"), + "bm25_rank": diag.get("bm25_rank"), + "post_fusion_rank": diag.get("post_fusion_rank"), + "post_reranking_rank": rerank_rank, + "cross_encoder_score": ce_score, }) - # Step 3: Final re-ranking - ranked_chunks = rerank(question, ranked_chunks, mode=cfg.rerank_mode, top_n=cfg.rerank_top_k) - # print("Reranked Chunks", type(ranked_chunks), len(ranked_chunks), type(ranked_chunks[0]) if ranked_chunks else "No chunks") - # print("Example reranked chunk content:", ranked_chunks[0] if ranked_chunks else "No chunks after reranking") - if not ranked_chunks and not cfg.disable_chunks: if console: console.print(f"\n{ANSWER_NOT_FOUND}\n") return ANSWER_NOT_FOUND - # Step 4: Generation + # Step 6: Generation model_path = cfg.gen_model system_prompt = args.system_prompt_mode or cfg.system_prompt_mode - use_double = getattr(args, "double_prompt", False) or cfg.use_double_prompt if use_double: stream_iter = double_answer( - question, - ranked_chunks, - model_path, + question, ranked_chunks, model_path, max_tokens=cfg.max_gen_tokens, system_prompt_mode=system_prompt, ) else: stream_iter = answer( - question, - ranked_chunks, - model_path, + question, ranked_chunks, model_path, max_tokens=cfg.max_gen_tokens, system_prompt_mode=system_prompt, ) if is_test_mode: - # We do not render MD in the test mode ans = "" for delta in stream_iter: ans += delta ans = dedupe_generated_text(ans) return ans, chunks_info, hyde_query - else: - # Accumulate the full text while rendering incremental Markdown chunks - ans = render_streaming_ans(console, stream_iter) - - # Logging - meta = artifacts.get("meta", []) - page_nums = get_page_numbers(topk_idxs, meta) - logger.save_chat_log( - query=question, - config_state=cfg.get_config_state(), - ordered_scores=scores[:len(topk_idxs)] if 'scores' in locals() else [], - chat_request_params={ - "system_prompt": system_prompt, - 
"max_tokens": cfg.max_gen_tokens - }, - top_idxs=topk_idxs, - chunks=chunks, - sources=sources, - page_map=page_nums, - full_response=ans, - top_k=len(topk_idxs), - additional_log_info=additional_log_info - ) - return ans + + # Step 7: Render + log + ans = render_streaming_ans(console, stream_iter) + + meta = artifacts.get("meta", []) + page_nums = get_page_numbers(topk_idxs, meta) + + # Build aligned lists for the logger (post-reranking order) + top_chunks_text = [chunks[i] for i in topk_idxs] + top_sources = [sources[i] for i in topk_idxs] + top_scores = [ + scores[ordered.index(i)] if i in ordered else 0.0 + for i in topk_idxs + ] + + logger.save_chat_log( + query=question, + config_state=cfg.get_config_state(), + ordered_scores=top_scores, + chat_request_params={ + "system_prompt": system_prompt, + "max_tokens": cfg.max_gen_tokens, + }, + top_idxs=topk_idxs, + chunks=top_chunks_text, + sources=top_sources, + page_map=page_nums, + full_response=ans, + top_k=len(topk_idxs), + chunk_diagnostics=chunk_diagnostics, + additional_log_info=additional_log_info, + ) + + return ans + def render_streaming_ans(console, stream_iter): ans = "" diff --git a/src/preprocessing/chunking.py b/src/preprocessing/chunking.py index aad39bdd..8b5047e3 100644 --- a/src/preprocessing/chunking.py +++ b/src/preprocessing/chunking.py @@ -1,4 +1,5 @@ import re +import statistics from abc import ABC, abstractmethod from dataclasses import dataclass from typing import List, Tuple, Optional @@ -11,75 +12,175 @@ class ChunkConfig(ABC): @abstractmethod def validate(self): pass - + @abstractmethod def to_string(self) -> str: pass + @dataclass class SectionRecursiveConfig(ChunkConfig): """Configuration for section-based chunking with recursive splitting.""" recursive_chunk_size: int recursive_overlap: int - + def to_string(self) -> str: - return f"chunk_mode=sections+recursive, chunk_size={self.recursive_chunk_size}, overlap={self.recursive_overlap}" + return ( + f"chunk_mode=sections+recursive, " 
+ f"chunk_size={self.recursive_chunk_size}, " + f"overlap={self.recursive_overlap}" + ) def validate(self): assert self.recursive_chunk_size > 0, "recursive_chunk_size must be > 0" assert self.recursive_overlap >= 0, "recursive_overlap must be >= 0" + assert self.recursive_overlap < self.recursive_chunk_size, \ + "recursive_overlap must be less than recursive_chunk_size" + # -------------------------- Chunking Strategies -------------------------- +# Separator hierarchy: tries the largest natural boundary first, +# falling back to smaller ones only when a chunk still exceeds chunk_size. +# The "" fallback guarantees a hard character-level split as a last resort. +_RECURSIVE_SEPARATORS = [ + "\n\n", # paragraph break — strongest preference + ". ", # declarative sentence end + "? ", # question end + "! ", # exclamation end +] + + class ChunkStrategy(ABC): """Abstract base for all chunking strategies.""" @abstractmethod def name(self) -> str: pass - + @abstractmethod def chunk(self, text: str) -> List[str]: pass - + @abstractmethod def artifact_folder_name(self) -> str: pass + class SectionRecursiveStrategy(ChunkStrategy): """ - Applies recursive character-based splitting to text. - This is meant to be used on already-extracted sections. + Applies recursive character-based splitting to already-extracted sections. + Tries paragraph → line → sentence → word → character boundaries in order. + The "" fallback ensures no chunk can exceed chunk_size_in_chars. 
""" def __init__(self, config: SectionRecursiveConfig): self.config = config - self.recursive_chunk_size = config.recursive_chunk_size - self.recursive_overlap = config.recursive_overlap + config.validate() + self._splitter = RecursiveCharacterTextSplitter( + chunk_size=config.recursive_chunk_size, + chunk_overlap=config.recursive_overlap, + separators=_RECURSIVE_SEPARATORS, + keep_separator=True, + ) def name(self) -> str: - return f"sections+recursive({self.recursive_chunk_size},{self.recursive_overlap})" + return ( + f"sections+recursive" + f"({self.config.recursive_chunk_size},{self.config.recursive_overlap})" + ) def artifact_folder_name(self) -> str: return "sections" def chunk(self, text: str) -> List[str]: - """ - Recursively splits text into smaller chunks based on sentence boundaries. - If a chunk exceeds recursive_chunk_size, it is further split. - """ - splitter = RecursiveCharacterTextSplitter( - chunk_size=self.recursive_chunk_size, - chunk_overlap=self.recursive_overlap, - separators=[". "] - ) - return splitter.split_text(text) + chunks = self._splitter.split_text(text) + + # Post-split validation: the "" separator should prevent this, + # but log loudly if anything still slips through. + over_limit = [ + (i, len(c)) for i, c in enumerate(chunks) + if len(c) > self.config.recursive_chunk_size + ] + if over_limit: + for idx, length in over_limit: + print( + f"[WARNING] Chunk {idx} has {length} chars, " + f"exceeding limit of {self.config.recursive_chunk_size}. " + f"This should not happen with the '' fallback separator — " + f"check for non-standard whitespace or very long tokens." + ) + + # Drop pure-whitespace chunks + return [c for c in chunks if c.strip()] + + +# -------------------------- Chunk Stats -------------------------- + +def print_chunk_stats(chunks: List[str], chunk_size_in_chars: int) -> None: + """ + Prints a statistical summary of chunk character lengths. + Useful for diagnosing chunking quality and context window overflow risk. 
+ + Args: + chunks: List of chunk strings to analyse. + chunk_size_in_chars: The configured hard limit, used to flag overflows. + """ + if not chunks: + print("[Chunk Stats] No chunks to analyse.") + return + + lengths = [len(c) for c in chunks] + total = len(lengths) + over = [l for l in lengths if l > chunk_size_in_chars] + pct_over = (len(over) / total) * 100 + + # Rough token estimate at 4 chars/token + est_tokens = [l / 4.0 for l in lengths] + + print("\n" + "="*55) + print(" CHUNK STATS") + print("="*55) + print(f" Total chunks : {total:,}") + print(f" Char limit : {chunk_size_in_chars:,}") + print(f" --- Character lengths ---") + print(f" Min : {min(lengths):,}") + print(f" Max : {max(lengths):,}") + print(f" Mean : {statistics.mean(lengths):,.1f}") + print(f" Median : {statistics.median(lengths):,.1f}") + print(f" Stdev : {statistics.stdev(lengths):,.1f}" if total > 1 else " Stdev : N/A") + print(f" --- Token estimates (chars ÷ 4) ---") + print(f" Min : {min(est_tokens):.0f}") + print(f" Max : {max(est_tokens):.0f}") + print(f" Mean : {statistics.mean(est_tokens):.1f}") + print(f" --- Overflow ---") + print(f" Chunks over limit : {len(over):,} / {total:,} ({pct_over:.1f}%)") + if over: + print(f" Largest offender : {max(over):,} chars (~{max(over)/4:.0f} tokens)") + + print(f" --- Distribution ---") + buckets = [ + (" 0 – 500 chars", 0, 500), + ("501 – 1000 chars", 501, 1000), + ("1001 – 1500 chars", 1001, 1500), + ("1501 – 2000 chars", 1501, 2000), + ("2001 – 2500 chars", 2001, 2500), + ("2500+ chars", 2501, float("inf")), + ] + for label, lo, hi in buckets: + count = sum(1 for l in lengths if lo <= l <= hi) + bar = "█" * (count * 30 // total) if total else "" + print(f" {label}: {count:>5,} {bar}") + print("="*55 + "\n") + # ----------------------------- Document Chunker --------------------------------- class DocumentChunker: """ - Chunk text via a provided strategy. - Table blocks (