diff --git a/backend/app/config.py b/backend/app/config.py index 2b7f54e9..a6bc027a 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -105,6 +105,9 @@ class Settings(BaseSettings): # ── LLM (HuggingFace Inference API) ────────────────── HF_TOKEN: str = os.getenv("HF_TOKEN", "") # HuggingFace API token (set in .env) + LLM_PROVIDER: str = os.getenv("LLM_PROVIDER", "huggingface") + GEMINI_API_KEY: str = os.getenv("GEMINI_API_KEY", "") + GROQ_API_KEY: str = os.getenv("GROQ_API_KEY", "") LLM_MODEL: str = "Qwen/Qwen2.5-72B-Instruct" LLM_MAX_NEW_TOKENS: int = 1024 LLM_TEMPERATURE: float = 0.3 diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index 4917cfe5..86e48d66 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -1,110 +1,211 @@ """ -Agentic RAG — intelligent routing using ReAct (Reasoning and Acting). -Intelligently chooses between PDF search, Web Search, and Math tools. +Agentic RAG — fast direct pipeline with optional ReAct tool use. + +For document questions: retrieve → format context → single LLM call (fast). +For math / web queries: fall back to ReAct agent with tools. """ import logging import json from typing import List, Dict, Any, Optional, Generator -from sympy import python - from huggingface_hub import InferenceClient -from langchain_classic.agents import create_react_agent, AgentExecutor -from langchain_core.prompts import PromptTemplate -from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace +from langchain_core.messages import HumanMessage, SystemMessage from app.config import get_settings from app.rag.retriever import retrieve -from app.rag.graph_retriever import get_entity_context -from app.rag.prompts import AGENT_SYSTEM_PROMPT -from app.exceptions import ExternalServiceException +from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, AGENT_SYSTEM_PROMPT from app.rag.security import MALFORMED_OUTPUT_MESSAGE, OutputParserError, parse_agent_output -from app.rag.tools import PDFSearchTool, MathTool, WebSearchTool from app.rag.tracing import trace_function logger = logging.getLogger(__name__) settings = get_settings() +# ── LLM singleton cache ────────────────────────────── +_llm_cache: Dict[str, Any] = {} -def get_llm_client(hf_token: Optional[str] = None) -> InferenceClient: - """Create a HuggingFace InferenceClient per-request.""" +def get_base_llm(hf_token: Optional[str] = None): + """Initialize (or return cached) base LLM based on provider.""" + cache_key = settings.LLM_PROVIDER + if cache_key in _llm_cache: + return _llm_cache[cache_key] - token = hf_token or settings.HF_TOKEN - - if not token: - raise ValueError( - "Hugging Face API token is missing. Please configure HF_TOKEN." + if settings.LLM_PROVIDER == "gemini": + from langchain_google_genai import ChatGoogleGenerativeAI + llm = ChatGoogleGenerativeAI( + api_key=settings.GEMINI_API_KEY, + model=settings.LLM_MODEL or "gemini-2.0-flash", + temperature=settings.LLM_TEMPERATURE, + max_tokens=settings.LLM_MAX_NEW_TOKENS, + timeout=120, + ) + elif settings.LLM_PROVIDER == "groq": + from langchain_groq import ChatGroq + llm = ChatGroq( + api_key=settings.GROQ_API_KEY, + model=settings.LLM_MODEL or "llama-3.3-70b-versatile", + temperature=settings.LLM_TEMPERATURE, + max_tokens=settings.LLM_MAX_NEW_TOKENS, + timeout=120, ) + else: + from langchain_huggingface import HuggingFaceEndpoint + llm = HuggingFaceEndpoint( + model=settings.LLM_MODEL, + huggingfacehub_api_token=hf_token or settings.HF_TOKEN, + max_new_tokens=settings.LLM_MAX_NEW_TOKENS, + temperature=settings.LLM_TEMPERATURE, + timeout=120, + ) + + _llm_cache[cache_key] = llm + return llm + + +def get_llm_client(hf_token: Optional[str] = None) -> InferenceClient: + """Create a HuggingFace InferenceClient per-request (for simple tasks).""" + return InferenceClient( + token=hf_token or settings.HF_TOKEN, + ) + - return InferenceClient(token=token) +def is_greeting(question: str) -> bool: + """Detect if the question is a casual greeting rather than a document query.""" + greetings = { + "hi", "hello", "hey", "how are you", "what's up", "whats up", + "good morning", "good evening", "good afternoon", "thanks", "thank you", + "bye", "goodbye", "help", "what can you do", "who are you", + } + return question.lower().strip().rstrip("!?.") in greetings -def _format_chat_history(messages: List[Dict[str, str]]) -> str: - if not messages: - return "" - lines = ["Previous conversation:"] - for msg in messages: - role = "User" if msg["role"] == "user" else "Assistant" - lines.append(f"{role}: {msg['content']}") - return "\n".join(lines) +def _needs_tools(question: str) -> bool: + """Detect if the question needs math or web search tools.""" + q = question.lower() + math_keywords = ["calculate", "compute", "sum", "total", "add", "subtract", + "multiply", "divide", "percentage", "average", "mean", "="] + web_keywords = ["latest", "current news", "today", "live", "stock price", + "weather", "real-time", "search the web"] + return any(k in q for k in math_keywords + web_keywords) + + +def _format_context(chunks: List[Dict[str, Any]]) -> str: + """Format retrieved chunks into a clear context string for the LLM.""" + if not chunks: + return "No relevant document content found." + + parts = [] + for i, chunk in enumerate(chunks, 1): + filename = chunk.get("filename", "document") + page = chunk.get("page", "?") + text = chunk.get("text", "").strip() + confidence = chunk.get("confidence", 0) + parts.append( + f"[Excerpt {i} | Source: {filename}, Page {page} | Confidence: {confidence}%]\n" + f"{text}" + ) + return "\n\n---\n\n".join(parts) -def get_agent_executor( +def _direct_rag_answer( + question: str, user_id: str, document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, - chat_history: Optional[List[Dict[str, str]]] = None, -): - """Initialize the LangChain ReAct agent executor.""" +) -> Dict[str, Any]: + """ + Fast direct RAG: retrieve chunks → build context → single LLM call. + Returns dict with 'answer' and 'sources'. + """ + # 1. Retrieve relevant chunks + chunks = retrieve( + query=question, + user_id=user_id, + document_id=document_id, + top_k=top_k or settings.TOP_K_RERANK, + ) - # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) - tools = [pdf_tool, MathTool(), WebSearchTool()] + # 2. Format context + context = _format_context(chunks) - # Initialize LLM - token = hf_token or settings.HF_TOKEN + # 3. Build prompt + user_prompt = RAG_PROMPT_TEMPLATE.format(context=context, question=question) - if not token: - raise ValueError( - "Hugging Face API token is missing. Please configure HF_TOKEN." - ) + # 4. Call LLM once + llm = get_base_llm(hf_token) + messages = [ + SystemMessage(content=SYSTEM_PROMPT), + HumanMessage(content=user_prompt), + ] + response = llm.invoke(messages) + answer = response.content.strip() if hasattr(response, "content") else str(response).strip() - llm = HuggingFaceEndpoint( - repo_id=settings.LLM_MODEL, - huggingfacehub_api_token=token, - max_new_tokens=settings.LLM_MAX_NEW_TOKENS, - temperature=settings.LLM_TEMPERATURE, - timeout=300, - ) + if not answer: + answer = "I was unable to generate a response. Please try rephrasing your question." - chat_llm = ChatHuggingFace(llm=llm) + # 5. Format sources + sources = [ + { + "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""), + "filename": chunk["filename"], + "page": chunk["page"], + "score": chunk["score"], + "confidence": chunk.get("confidence", 0), + } + for chunk in chunks + ] - # Setup Agent - prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT) - agent = create_react_agent(chat_llm, tools, prompt) + return {"answer": answer, "sources": sources, "chunks": chunks} + + +def _react_agent_answer( + question: str, + user_id: str, + document_id: Optional[str] = None, + hf_token: Optional[str] = None, + top_k: Optional[int] = None, +) -> Dict[str, Any]: + """ + ReAct agent for math/web tool use (slower but handles tool calls). + """ + from langchain_classic.agents import create_react_agent, AgentExecutor + from langchain_core.prompts import PromptTemplate + from app.rag.tools import PDFSearchTool, MathTool, WebSearchTool + pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) + tools = [pdf_tool, MathTool(), WebSearchTool()] + llm = get_base_llm(hf_token) + prompt = PromptTemplate.from_template(AGENT_SYSTEM_PROMPT) + agent = create_react_agent(llm, tools, prompt) executor = AgentExecutor( agent=agent, tools=tools, - verbose=True, + verbose=False, handle_parsing_errors=True, - max_iterations=5, + max_iterations=4, ) - formatted_history = _format_chat_history(chat_history) if chat_history else "" + result = executor.invoke({"input": question}) + raw_answer = result.get("output", "") + try: + answer = parse_agent_output(raw_answer) + except OutputParserError as e: + logger.warning(f"Rejected malformed LLM output: {e}") + answer = MALFORMED_OUTPUT_MESSAGE - return executor, pdf_tool, formatted_history + sources = [ + { + "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""), + "filename": chunk["filename"], + "page": chunk["page"], + "score": chunk["score"], + "confidence": chunk.get("confidence", 0), + } + for chunk in getattr(pdf_tool, "last_sources", []) + ] -def is_greeting(question: str) -> bool: - """Detect if the question is a casual greeting rather than a document query.""" - greetings = { - "hi", "hello", "hey", "how are you", "what's up", "whats up", - "good morning", "good evening", "good afternoon", "thanks", "thank you", - "bye", "goodbye", "help", "what can you do", "who are you", - } - return question.lower().strip().rstrip("!?.") in greetings + return {"answer": answer, "sources": sources} @trace_function( @@ -121,62 +222,46 @@ def generate_answer( document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, - chat_history: Optional[List[Dict[str, str]]] = None, ) -> Dict[str, Any]: """ - Agentic generation: retrieve via tools → reason → generate answer. + Generate an answer using the fast direct RAG pipeline. + Falls back to ReAct agent for math/web queries. """ # ── Handle greetings ───────────────────────────── if is_greeting(question): - client = get_llm_client(hf_token) + llm = get_base_llm(hf_token) try: messages = [ - {"role": "system", "content": "You are Document AI Analyst, a friendly AI assistant."}, - {"role": "user", "content": question}, + SystemMessage( + content="You are Document AI Analyst, a friendly AI assistant " + "that helps users analyze uploaded documents." + ), + HumanMessage(content=question), ] - response = client.chat_completion( - messages=messages, - model=settings.LLM_MODEL, - max_tokens=256, - ) - answer = response.choices[0].message.content.strip() if response.choices else "Hello! How can I help you today?" - except Exception: + response = llm.invoke(messages) + answer = response.content.strip() if hasattr(response, "content") else str(response).strip() + if not answer: + answer = "Hello! How can I help you today?" + except Exception as e: + logger.error(f"Greeting execution error: {e}") answer = "Hello! I'm Document AI Analyst. How can I help you with your documents?" return {"answer": answer, "sources": []} - # ── Run Agent ──────────────────────────────────── + # ── Route to appropriate pipeline ──────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) - result = executor.invoke({"input": question, "chat_history": formatted_history}) - - raw_answer = result.get("output", "") - try: - answer = parse_agent_output(raw_answer) - except OutputParserError as e: - logger.warning(f"Rejected malformed LLM output: {e}") - answer = MALFORMED_OUTPUT_MESSAGE - - # Retrieve sources from the PDF tool if it was used - sources = [ - { - "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""), - "filename": chunk["filename"], - "page": chunk["page"], - "score": chunk["score"], - "confidence": chunk.get("confidence", 0), - "bbox": chunk.get("bbox", ""), - } - for chunk in getattr(pdf_tool, "last_sources", []) - ] - - return {"answer": answer, "sources": sources} + if _needs_tools(question): + logger.info("Routing to ReAct agent (math/web tools needed)") + return _react_agent_answer(question, user_id, document_id, hf_token, top_k) + else: + logger.info("Routing to direct RAG pipeline") + return _direct_rag_answer(question, user_id, document_id, hf_token, top_k) - except (OutputParserError, ValueError) as e: - logger.warning(f"Agent output error: {e}") - return {"answer": MALFORMED_OUTPUT_MESSAGE, "sources": []} except Exception as e: logger.error(f"Agent execution error: {e}") - raise ExternalServiceException("HuggingFace", str(e)) from e + return { + "answer": f"I encountered an error while processing your request: {str(e)}", + "sources": [], + } @trace_function( @@ -193,67 +278,67 @@ def generate_answer_stream( document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, - chat_history: Optional[List[Dict[str, str]]] = None, ) -> Generator[str, None, None]: """ - Streaming Agentic pipeline. + Streaming RAG pipeline — retrieves once then streams LLM tokens. + Much faster than the ReAct multi-turn loop. """ # ── Handle greetings ───────────────────────────── if is_greeting(question): yield f"data: {json.dumps({'type': 'sources', 'data': []})}\n\n" - client = get_llm_client(hf_token) + llm = get_base_llm(hf_token) try: - stream = client.chat_completion( - messages=[{"role": "user", "content": question}], - model=settings.LLM_MODEL, - max_tokens=256, - stream=True, - ) - for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - yield f"data: {json.dumps({'type': 'token', 'data': chunk.choices[0].delta.content})}\n\n" + messages = [HumanMessage(content=question)] + for chunk in llm.stream(messages): + content = chunk.content if hasattr(chunk, "content") else str(chunk) + if content: + yield f"data: {json.dumps({'type': 'token', 'data': content})}\n\n" except Exception as e: + logger.error(f"Greeting streaming error: {e}") yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n" yield f"data: {json.dumps({'type': 'done'})}\n\n" return - # ── Run Agent ──────────────────────────────────── + # ── Direct RAG streaming ────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) - - sources_sent = False - - for step in executor.stream({"input": question, "chat_history": formatted_history}): - if "actions" in step: - continue - - elif "intermediate_steps" in step: - if not sources_sent and getattr(pdf_tool, "last_sources", []): - sources = [ - { - "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""), - "filename": chunk["filename"], - "page": chunk["page"], - "score": chunk["score"], - "confidence": chunk.get("confidence", 0), - "bbox": chunk.get("bbox", ""), - } - for chunk in pdf_tool.last_sources - ] - yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n" - sources_sent = True - - elif "output" in step: - full_answer = step["output"] - try: - clean_answer = parse_agent_output(full_answer) - except OutputParserError as e: - logger.warning(f"Rejected malformed streamed LLM output: {e}") - clean_answer = MALFORMED_OUTPUT_MESSAGE - yield f"data: {json.dumps({'type': 'token', 'data': clean_answer})}\n\n" + # Step 1: Retrieve chunks first (fast) + chunks = retrieve( + query=question, + user_id=user_id, + document_id=document_id, + top_k=top_k or settings.TOP_K_RERANK, + ) + + # Step 2: Emit sources immediately so UI can show them + sources = [ + { + "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""), + "filename": chunk["filename"], + "page": chunk["page"], + "score": chunk["score"], + "confidence": chunk.get("confidence", 0), + } + for chunk in chunks + ] + yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n" + + # Step 3: Build context and prompt + context = _format_context(chunks) + user_prompt = RAG_PROMPT_TEMPLATE.format(context=context, question=question) + llm = get_base_llm(hf_token) + messages = [ + SystemMessage(content=SYSTEM_PROMPT), + HumanMessage(content=user_prompt), + ] + + # Step 4: Stream LLM tokens + for chunk in llm.stream(messages): + content = chunk.content if hasattr(chunk, "content") else str(chunk) + if content: + yield f"data: {json.dumps({'type': 'token', 'data': content})}\n\n" except Exception as e: - logger.error(f"Agent streaming error: {e}") + logger.error(f"Streaming error: {e}") yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n" yield f"data: {json.dumps({'type': 'done'})}\n\n" diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index e542c17f..63c80bcd 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -31,12 +31,31 @@ def invoke(self, query): from app.rag.embeddings import embed_query from app.rag.tracing import trace_function from app.rag.vectorstore import query_chunks -from app.rag.reranker import get_reranker logger = logging.getLogger(__name__) settings = get_settings() MAX_QUERY_VARIANTS = 4 +# ── Singleton reranker ─────────────────────────────── +_reranker = None + + +def get_reranker(): + """Load cross-encoder reranker model (singleton).""" + global _reranker + + if _reranker is None: + try: + from sentence_transformers import CrossEncoder + logger.info(f"Loading reranker: {settings.RERANKER_MODEL}") + _reranker = CrossEncoder(settings.RERANKER_MODEL, max_length=512) + logger.info("Reranker loaded successfully") + except Exception as e: + logger.warning(f"Failed to load reranker: {e}. Falling back to embedding-only retrieval.") + _reranker = "disabled" + + return _reranker if _reranker != "disabled" else None + class CustomVectorRetriever(BaseRetriever): user_id: str = Field(description="User ID") @@ -90,7 +109,16 @@ def transform_query(query: str) -> List[str]: def _generate_query_variants(query: str) -> List[str]: - """Use the configured LLM to split/rewrite a user query for semantic search.""" + """Use the configured LLM to split/rewrite a user query for semantic search. + + Only uses the HuggingFace InferenceClient when LLM_PROVIDER is 'huggingface'. + Groq/Gemini models are not available on the HF router, so we skip query + transformation for those providers to avoid slow failing API calls per query. + """ + # Skip HF router for non-HuggingFace providers — their models aren't on the HF router + if settings.LLM_PROVIDER.lower() != "huggingface": + return [] + if not settings.HF_TOKEN: return [] @@ -206,9 +234,10 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @trace_function( "retrieve", - metadata_factory=lambda query, user_id, document_id=None, top_k=None: { + metadata_factory=lambda query, user_id, document_id=None, top_k=None, **kwargs: { "user_id": user_id, "document_id": document_id, + "top_k": top_k, "embedding_model": settings.EMBEDDING_MODEL, "reranker_model": settings.RERANKER_MODEL, "top_k_retrieval": settings.TOP_K_RETRIEVAL, @@ -228,18 +257,20 @@ def retrieve( Returns chunks with confidence scores. """ + actual_top_k_retrieval = max(settings.TOP_K_RETRIEVAL, top_k * 2) if top_k else settings.TOP_K_RETRIEVAL + actual_top_k_rerank = top_k if top_k else settings.TOP_K_RERANK + # ── Stage 1: Hybrid Search with Query Transformation ───────────── - effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL vector_retriever = CustomVectorRetriever( user_id=user_id, document_id=document_id, - top_k=effective_top_k, + top_k=actual_top_k_retrieval, ) bm25_retriever = CustomBM25Retriever( user_id=user_id, document_id=document_id, - top_k=effective_top_k, + top_k=actual_top_k_retrieval, ) ensemble_retriever = EnsembleRetriever( @@ -264,30 +295,41 @@ def retrieve( # ── Stage 2: Cross-encoder reranking ───────────── reranker = get_reranker() - - if reranker is not None: - top_chunks = reranker.rerank( - query=query, - documents=candidates, - top_k=settings.TOP_K_RERANK - ) - else: - # Fall back to hybrid scores (no reranker) - candidates.sort(key=lambda x: x.get("score", 0), reverse=True) - top_chunks = candidates[:settings.TOP_K_RERANK] - # top_chunks is now always defined + if reranker is not None and len(candidates) > 1: + try: + # Build query-document pairs for reranking + pairs = [(query, chunk["text"]) for chunk in candidates] + rerank_scores = reranker.predict(pairs, convert_to_numpy=True) # type: ignore[call-overload] + + # Assign rerank scores + for i, chunk in enumerate(candidates): + chunk["rerank_score"] = float(rerank_scores[i]) + + # Sort by rerank score (descending) + candidates.sort(key=lambda x: x.get("rerank_score", 0), reverse=True) + + except Exception as e: + logger.warning(f"Reranking failed, using hybrid scores: {e}") + + # Ensure candidates are sorted by best available score + candidates.sort(key=lambda x: x.get("rerank_score", x.get("score", 0)), reverse=True) + + # ── Take top-K after reranking ─────────────────── + top_chunks = candidates[:actual_top_k_rerank] + # ── Calculate confidence percentages ───────────── if top_chunks: - max_score = max( - chunk.get("rerank_score", chunk.get("score", 0)) + max_score: float = max( + float(chunk.get("rerank_score") or chunk.get("score") or 0.0) for chunk in top_chunks ) - max_score = max(max_score, 0.001) + max_score = max(max_score, 0.001) # Avoid division by zero for chunk in top_chunks: - raw = chunk.get("rerank_score", chunk.get("score", 0)) + raw: float = float(chunk.get("rerank_score") or chunk.get("score") or 0.0) chunk["confidence"] = round((raw / max_score) * 100, 1) + # Clean up internal score if "rerank_score" in chunk: chunk["score"] = round(chunk["rerank_score"], 4) del chunk["rerank_score"] diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 00f14cc3..41697858 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -1,27 +1,18 @@ """ Chat routes — ask questions with RAG, stream responses via SSE, manage history. """ - -import html import json import time from datetime import datetime, timezone -from io import BytesIO import logging from typing import Optional, List -from fastapi import APIRouter, Depends, Request, WebSocket, WebSocketDisconnect, Query +from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import Response, StreamingResponse from sqlalchemy.orm import Session from app.auth import get_current_user -from app.cache import get_cached_response, set_cached_response from app.database import get_db -from app.exceptions import ( - NotFoundException, - UnauthorizedException, - ValidationException, -) from app.metrics import record_query_response_time from app.models import User, ChatMessage, Document, SharedMessage, ChatSession from app.rate_limit import CHAT_QUERY_RATE_LIMIT, limiter @@ -44,170 +35,6 @@ router = APIRouter(prefix="/chat", tags=["Chat"]) -@router.websocket("/ws") -async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): - """WebSocket endpoint for streaming agentic thoughts and tokens. - - Authenticate via `token` query param or expect first JSON message - containing `{token, question, document_id?, session_id?}`. - """ - await websocket.accept() - - # Simple DB-backed auth similar to get_current_user - from app.database import SessionLocal - from app.auth import decode_token - from app.models import ApiKey, User - - db = SessionLocal() - user = None - - try: - # Try token from query param - if token: - tok = token - initial_payload = None - else: - # Expect first message to contain token and the payload - msg = await websocket.receive_json() - tok = msg.get("token") - initial_payload = msg - - if not tok: - await websocket.send_json({"type": "error", "data": "Missing token"}) - await websocket.close() - return - - # API key check - if tok.startswith("pdf_rag_"): - import hashlib - hashed = hashlib.sha256(tok.encode("utf-8")).hexdigest() - api_key = db.query(ApiKey).filter(ApiKey.hashed_key == hashed, ApiKey.is_active == True).first() - if not api_key: - await websocket.send_json({"type": "error", "data": "Invalid API key"}) - await websocket.close() - return - user = api_key.user - else: - user_id = decode_token(tok) - if not user_id: - await websocket.send_json({"type": "error", "data": "Invalid or expired token"}) - await websocket.close() - return - user = db.query(User).filter(User.id == user_id).first() - - if not user: - await websocket.send_json({"type": "error", "data": "User not found"}) - await websocket.close() - return - - # Receive or reuse initial payload - if initial_payload: - payload = initial_payload - else: - payload = await websocket.receive_json() - - question = payload.get("question") - document_id = payload.get("document_id") - session_id = payload.get("session_id") - - from app.rag.security import validate_user_input, UnsafePromptError - - try: - validate_user_input(question) - except UnsafePromptError as exc: - await websocket.send_json({"type": "error", "data": str(exc)}) - await websocket.close() - return - - # Validate document if given - if document_id: - doc = db.query(Document).filter( - Document.id == document_id, - Document.user_id == user.id, - Document.is_deleted.is_(False), - ).first() - if not doc: - await websocket.send_json({"type": "error", "data": "Document not found"}) - await websocket.close() - return - if doc.status != "ready": - progress = getattr(doc, "processing_progress", None) - stage = getattr(doc, "processing_stage", None) - detail = f"Document is still {doc.status}." - if progress is not None: - detail += f" Progress: {progress}%" - if stage: - detail += f" Stage: {stage}" - await websocket.send_json({"type": "error", "data": detail}) - await websocket.close() - return - - # Resolve or create session - if not session_id: - session = db.query(ChatSession).filter(ChatSession.user_id == user.id).first() - if not session: - session = ChatSession(user_id=user.id, title="Default Chat") - db.add(session) - db.commit() - db.refresh(session) - session_id = session.id - - # Build chat history - recent_messages = ( - db.query(ChatMessage) - .filter( - ChatMessage.session_id == session_id, - ChatMessage.user_id == user.id, - ) - .order_by(ChatMessage.created_at.desc()) - .limit(12) - .all() - ) - recent_messages.reverse() - chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] - - # Save user message - _save_message(db, user.id, document_id, "user", question, session_id=session_id) - - # Stream answer using existing generator and forward structured events - try: - for chunk in generate_answer_stream( - question=question, - user_id=user.id, - document_id=document_id, - hf_token=user.hf_token, - chat_history=chat_history, - ): - # chunk is SSE-style string like 'data: {json}\n\n' or similar - try: - if chunk.startswith("data: "): - payload = json.loads(chunk[6:].strip()) - await websocket.send_json(payload) - else: - # Fallback: send raw token - await websocket.send_json({"type": "token", "data": chunk}) - except Exception: - await websocket.send_json({"type": "token", "data": chunk}) - - # Notify client - await websocket.send_json({"type": "done"}) - - except WebSocketDisconnect: - return - except Exception as e: - await websocket.send_json({"type": "error", "data": str(e)}) - - except WebSocketDisconnect: - return - except Exception as e: - try: - await websocket.send_json({"type": "error", "data": str(e)}) - except Exception: - pass - finally: - db.close() - - @router.get( "/share/{message_id}", response_model=ShareAnswerResponse, @@ -227,17 +54,13 @@ def get_shared_answer( exposed. User prompts, private chat history, and unshared answers remain protected. """ - message = ( - db.query(ChatMessage) - .filter( - ChatMessage.id == message_id, - ChatMessage.role == "assistant", - ) - .first() - ) + message = db.query(ChatMessage).filter( + ChatMessage.id == message_id, + ChatMessage.role == "assistant", + ).first() if not message or not db.query(SharedMessage).filter(SharedMessage.message_id == message.id).first(): - raise NotFoundException("Shared answer") + raise HTTPException(status_code=404, detail="Shared answer not found") return _share_answer_response(message) @@ -247,7 +70,8 @@ def get_shared_answer( response_model=ShareLinkResponse, summary="Create a public share link for an assistant answer", description=( - "Marks one authenticated user's assistant message as shareable and " "returns the frontend share URL." + "Marks one authenticated user's assistant message as shareable and " + "returns the frontend share URL." ), ) def create_share_link( @@ -260,20 +84,16 @@ def create_share_link( The message must belong to the authenticated user and must have the assistant role. User-authored messages cannot be shared through this route. """ - message = ( - db.query(ChatMessage) - .filter( - ChatMessage.id == message_id, - ChatMessage.user_id == user.id, - ) - .first() - ) + message = db.query(ChatMessage).filter( + ChatMessage.id == message_id, + ChatMessage.user_id == user.id, + ).first() if not message: - raise NotFoundException("Message") + raise HTTPException(status_code=404, detail="Message not found") if message.role != "assistant": - raise ValidationException("Only assistant messages can be shared") + raise HTTPException(status_code=400, detail="Only assistant messages can be shared") shared_message = db.query(SharedMessage).filter(SharedMessage.message_id == message.id).first() if not shared_message: @@ -299,7 +119,10 @@ def get_chat_sessions( ): """Retrieve all chat sessions for the authenticated user.""" sessions = ( - db.query(ChatSession).filter(ChatSession.user_id == user.id).order_by(ChatSession.created_at.desc()).all() + db.query(ChatSession) + .filter(ChatSession.user_id == user.id) + .order_by(ChatSession.created_at.desc()) + .all() ) return sessions @@ -349,7 +172,7 @@ def rename_chat_session( .first() ) if not session: - raise NotFoundException("Chat session") + raise HTTPException(status_code=404, detail="Chat session not found") session.title = payload.title db.commit() db.refresh(session) @@ -376,7 +199,7 @@ def delete_chat_session( .first() ) if not session: - raise NotFoundException("Chat session") + raise HTTPException(status_code=404, detail="Chat session not found") db.delete(session) db.commit() return Response(status_code=204) @@ -403,7 +226,7 @@ def get_session_history( .first() ) if not session: - raise NotFoundException("Chat session") + raise HTTPException(status_code=404, detail="Chat session not found") messages = ( db.query(ChatMessage) @@ -443,7 +266,6 @@ def generate_answer( document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, - chat_history: Optional[list] = None, ): from app.rag.agent import generate_answer as _generate_answer @@ -453,7 +275,6 @@ def generate_answer( document_id=document_id, hf_token=hf_token, top_k=top_k, - chat_history=chat_history, ) @@ -463,7 +284,6 @@ def generate_answer_stream( document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, - chat_history: Optional[list] = None, ): from app.rag.agent import generate_answer_stream as _generate_answer_stream @@ -473,7 +293,6 @@ def generate_answer_stream( document_id=document_id, hf_token=hf_token, top_k=top_k, - chat_history=chat_history, ) @@ -499,33 +318,25 @@ def ask_question( try: validate_user_input(payload.question) except UnsafePromptError as exc: - raise ValidationException(str(exc)) from exc + raise HTTPException(status_code=400, detail=str(exc)) from exc # Validate document exists if specified if payload.document_id: - doc = ( - db.query(Document) - .filter( - Document.id == payload.document_id, - Document.user_id == user.id, - Document.is_deleted.is_(False), - ) - .first() - ) + doc = db.query(Document).filter( + Document.id == payload.document_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ).first() if not doc: - raise NotFoundException("Document") + raise HTTPException(status_code=404, detail="Document not found") if doc.status != "ready": - progress = getattr(doc, "processing_progress", None) - stage = getattr(doc, "processing_stage", None) - detail = f"Document is still {doc.status}. Please wait for processing to complete." - if progress is not None: - detail += f" Progress: {progress}%" - if stage: - detail += f" Stage: {stage}" - raise ValidationException(detail) - + raise HTTPException( + status_code=400, + detail=f"Document is still {doc.status}. Please wait for processing to complete.", + ) + # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() @@ -541,53 +352,22 @@ def ask_question( db.refresh(session) session_id = session.id - # Build chat history from last 6 exchanges - recent_messages = ( - db.query(ChatMessage) - .filter( - ChatMessage.session_id == session_id, - ChatMessage.user_id == user.id, - ) - .order_by(ChatMessage.created_at.desc()) - .limit(12) - .all() - ) - recent_messages.reverse() - chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] - - # Cache check — return instantly if this (question, document) was answered before - cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), - question=payload.question, - ) - if cached_answer is not None: - logger.debug("Returning cached response for question: %s", payload.question[:40]) - return ChatResponse( - answer=cached_answer, - sources=[], - document_id=payload.document_id, - ) - result = generate_answer( question=payload.question, user_id=user.id, document_id=payload.document_id, hf_token=user.hf_token, top_k=payload.top_k, - chat_history=chat_history, - ) - - # Store result in cache for future identical questions - set_cached_response( - document_id=str(payload.document_id or ""), - question=payload.question, - answer=result["answer"], ) # Save to chat history - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) _save_message( - db, user.id, payload.document_id, "assistant", result["answer"], result["sources"], session_id=session_id + db, user.id, payload.document_id, "user", payload.question, + session_id=session_id, + ) + _save_message( + db, user.id, payload.document_id, "assistant", + result["answer"], result["sources"], session_id=session_id, ) return ChatResponse( @@ -618,33 +398,25 @@ def ask_question_stream( try: validate_user_input(payload.question) except UnsafePromptError as exc: - raise ValidationException(str(exc)) from exc + raise HTTPException(status_code=400, detail=str(exc)) from exc # Validate document if payload.document_id: - doc = ( - db.query(Document) - .filter( - Document.id == payload.document_id, - Document.user_id == user.id, - Document.is_deleted.is_(False), - ) - .first() - ) + doc = db.query(Document).filter( + Document.id == payload.document_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ).first() if not doc: - raise NotFoundException("Document") + raise HTTPException(status_code=404, detail="Document not found") if doc.status != "ready": - progress = getattr(doc, "processing_progress", None) - stage = getattr(doc, "processing_stage", None) - detail = f"Document is still {doc.status}. Please wait for processing to complete." - if progress is not None: - detail += f" Progress: {progress}%" - if stage: - detail += f" Stage: {stage}" - raise ValidationException(detail) - + raise HTTPException( + status_code=400, + detail=f"Document is still {doc.status}. Please wait for processing to complete.", + ) + # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() @@ -662,47 +434,9 @@ def ask_question_stream( db.refresh(session) session_id = session.id - # Build chat history from last 6 exchanges (before saving current message) - recent_messages = ( - db.query(ChatMessage) - .filter( - ChatMessage.session_id == session_id, - ChatMessage.user_id == user.id, - ) - .order_by(ChatMessage.created_at.desc()) - .limit(12) - .all() - ) - recent_messages.reverse() - chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] - # Save user message immediately _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) - # Cache check before starting the stream - cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), - question=payload.question, - ) - if cached_answer is not None: - logger.debug("Returning cached stream response for question: %s", payload.question[:40]) - - async def cached_event_stream(): - payload_json = json.dumps({"type": "token", "data": cached_answer}) - yield f"data: {payload_json}\n\n" - done_json = json.dumps({"type": "done"}) - yield f"data: {done_json}\n\n" - - return StreamingResponse( - cached_event_stream(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - # Stream response def event_stream(): full_answer = "" @@ -715,7 +449,6 @@ def event_stream(): document_id=payload.document_id, hf_token=user.hf_token, top_k=payload.top_k, - chat_history=chat_history, ): yield chunk @@ -730,21 +463,16 @@ def event_stream(): except Exception: pass - # Cache the full answer for future identical questions - if full_answer: - set_cached_response( - document_id=str(payload.document_id or ""), - question=payload.question, - answer=full_answer, - ) - # Save assistant response to history - from app.database import get_db_session - - with get_db_session() as save_db: + from app.database import SessionLocal + save_db = SessionLocal() + try: _save_message( - save_db, user.id, payload.document_id, "assistant", full_answer, sources, session_id=session_id + save_db, user.id, payload.document_id, + "assistant", full_answer, sources, session_id=session_id, ) + finally: + save_db.close() finally: record_query_response_time(time.perf_counter() - started_at) @@ -790,16 +518,14 @@ def get_chat_history( except Exception: pass - formatted.append( - ChatMessageResponse( - id=str(msg.id), - role=msg.role, - content=msg.content, - sources=sources, - feedback=msg.feedback, - created_at=msg.created_at, - ) - ) + formatted.append(ChatMessageResponse( + id=str(msg.id), + role=msg.role, + content=msg.content, + sources=sources, + feedback=msg.feedback, + created_at=msg.created_at, + )) return ChatHistoryResponse(messages=formatted, document_id=document_id) @@ -821,17 +547,18 @@ def export_chat_history( """Export the chat history for a document as a downloadable file.""" from app.auth import decode_token as _decode + # Resolve user from query-param token (browser download links can't set headers) resolved_user = None if token: user_id = _decode(token) if user_id: resolved_user = db.query(User).filter(User.id == user_id).first() - + if resolved_user is None: - raise UnauthorizedException("Authentication required") + raise HTTPException(status_code=401, detail="Authentication required") if format not in ("md", "txt", "pdf"): - raise ValidationException("Format must be 'md', 'txt', or 'pdf'") + raise HTTPException(status_code=400, detail="Format must be 'md', 'txt', or 'pdf'") # Verify document exists and belongs to user doc = db.query(Document).filter( @@ -841,7 +568,7 @@ def export_chat_history( ).first() if not doc: - raise NotFoundException("Document") + raise HTTPException(status_code=404, detail="Document not found") messages = ( db.query(ChatMessage) @@ -854,7 +581,7 @@ def export_chat_history( ) if not messages: - raise NotFoundException("Chat history") + raise HTTPException(status_code=404, detail="No chat history found for this document") if format == "md": content = _format_markdown(doc, messages) @@ -866,7 +593,6 @@ def export_chat_history( extension = "txt" else: from app.routes.chat_export import format_pdf as _format_pdf - content = _format_pdf(doc, messages) media_type = "application/pdf" extension = "pdf" @@ -931,9 +657,9 @@ def submit_feedback( ).first() if not msg: - raise NotFoundException("Message") + raise HTTPException(status_code=404, detail="Message not found") if msg.role != "assistant": - raise ValidationException("Can only provide feedback on assistant messages") + raise HTTPException(status_code=400, detail="Can only provide feedback on assistant messages") msg.feedback = payload.feedback db.commit() @@ -1026,11 +752,9 @@ def _format_markdown(doc, messages) -> str: lines.append("**Sources:**") lines.append("") for i, src in enumerate(sources, 1): - lines.append( - f"> **[{i}]** {src.get('filename', 'Unknown')}, " - f"Page {src.get('page', '?')} " - f"(Confidence: {src.get('confidence', 0)}%)" - ) + lines.append(f"> **[{i}]** {src.get('filename', 'Unknown')}, " + f"Page {src.get('page', '?')} " + f"(Confidence: {src.get('confidence', 0)}%)") text_preview = src.get("text", "")[:150] if text_preview: lines.append(f"> {text_preview}...") @@ -1069,11 +793,9 @@ def _format_plaintext(doc, messages) -> str: lines.append("") lines.append("Sources:") for i, src in enumerate(sources, 1): - lines.append( - f" [{i}] {src.get('filename', 'Unknown')}, " - f"Page {src.get('page', '?')} " - f"(Confidence: {src.get('confidence', 0)}%)" - ) + lines.append(f" [{i}] {src.get('filename', 'Unknown')}, " + f"Page {src.get('page', '?')} " + f"(Confidence: {src.get('confidence', 0)}%)") except Exception: pass diff --git a/backend/app/schemas.py b/backend/app/schemas.py index c630538f..5dd69b24 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -1,31 +1,10 @@ """ Pydantic schemas for API request/response validation. """ -from pydantic import BaseModel, EmailStr, Field, field_validator -from typing import Optional, List, Any +from pydantic import BaseModel, EmailStr, Field, ConfigDict +from typing import Optional, List from datetime import datetime -from app.models import UserRole -from app.password_validation import validate_password - - -class ErrorDetail(BaseModel): - field: str - message: str - - -class ErrorEnvelope(BaseModel): - code: str - message: str - details: dict[str, Any] = {} - request_id: str | None = None - - -class ErrorResponse(BaseModel): - error: ErrorEnvelope - - -class ValidationErrorResponse(BaseModel): - error: ErrorEnvelope +from .models import UserRole # ── Auth ───────────────────────────────────────────── @@ -33,13 +12,7 @@ class ValidationErrorResponse(BaseModel): class UserRegister(BaseModel): username: str = Field(..., min_length=3, max_length=80) email: EmailStr - password: str = Field(..., min_length=8) - - @field_validator("password") - @classmethod - def validate_password_strength(cls, value: str) -> str: - validate_password(value) - return value + password: str = Field(..., min_length=6) class UserLogin(BaseModel): @@ -47,44 +20,25 @@ class UserLogin(BaseModel): password: str -class EmailVerificationRequest(BaseModel): - email: EmailStr - - -class MessageResponse(BaseModel): - message: str - - -class RegistrationResponse(MessageResponse): - email: EmailStr - verification_url: Optional[str] = None - - class GoogleLoginRequest(BaseModel): id_token: str = Field(..., min_length=10) class UserUpdate(BaseModel): email: Optional[EmailStr] = None - username:Optional[str] = None - -class UserProfileUpdate(BaseModel): username: Optional[str] = None - display_name: Optional[str] = None + + class UserUpdateResponse(BaseModel): id: str username: str email: EmailStr + class UpdatePassword(BaseModel): password: str confirm_password: str - @field_validator("password") - @classmethod - def validate_password_strength(cls, value: str) -> str: - validate_password(value) - return value class UpdatePasswordResponse(BaseModel): id: str @@ -93,19 +47,6 @@ class UpdatePasswordResponse(BaseModel): password_changed: bool = True -class WorkspaceInviteRequest(BaseModel): - email: EmailStr - workspace_name: str = Field(..., min_length=1, max_length=100) - message: Optional[str] = None - - -class WorkspaceInviteResponse(BaseModel): - email: EmailStr - workspace_name: str - invite_link: str - expires_in_hours: int - - class TokenResponse(BaseModel): access_token: str refresh_token: str @@ -122,22 +63,13 @@ class HFTokenUpdate(BaseModel): hf_token: str -class GoogleDriveAuthUrlResponse(BaseModel): - auth_url: str - - -class GoogleDriveStatusResponse(BaseModel): - connected: bool - - class ApiKeyResponse(BaseModel): id: str name: str key_preview: str created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ApiKeyCreateResponse(BaseModel): @@ -147,8 +79,7 @@ class ApiKeyCreateResponse(BaseModel): created_at: datetime raw_key: str - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class UserResponse(BaseModel): @@ -157,14 +88,10 @@ class UserResponse(BaseModel): email: str role: UserRole is_admin: bool - is_verified: bool hf_token: Optional[str] = None - display_name: Optional[str] = None - avatar_url: Optional[str] = None created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # ── Documents ──────────────────────────────────────── @@ -178,24 +105,9 @@ class DocumentResponse(BaseModel): status: str error_message: Optional[str] = None uploaded_at: datetime - summary: Optional[str] = None # New field for document summary - task_id: Optional[str] = None - extracted_urls: Optional[List[str]] = None - - class Config: - from_attributes = True - + summary: Optional[str] = None # New field for document summary -class DocumentRename(BaseModel): - name: str = Field(..., min_length=1, max_length=255) - - @field_validator("name") - @classmethod - def validate_name(cls, value: str) -> str: - stripped = value.strip() - if not stripped: - raise ValueError("Document name cannot be empty") - return stripped + model_config = ConfigDict(from_attributes=True) class DocumentStatusResponse(BaseModel): @@ -204,15 +116,8 @@ class DocumentStatusResponse(BaseModel): page_count: int chunk_count: int error_message: Optional[str] = None - processing_progress: Optional[int] = None - processing_stage: Optional[str] = None - retry_count: Optional[int] = None - last_error_traceback: Optional[str] = None - processing_started_at: Optional[datetime] = None - completed_at: Optional[datetime] = None - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class DocumentListResponse(BaseModel): @@ -250,7 +155,7 @@ class ChatRequest(BaseModel): document_id: Optional[str] = None document_ids: Optional[List[str]] = None session_id: Optional[str] = None - top_k: int = Field(default=5, ge=1, le=20) + top_k: Optional[int] = Field(None, ge=1, le=50) class SourceChunk(BaseModel): @@ -259,7 +164,6 @@ class SourceChunk(BaseModel): page: int score: float confidence: float - bbox: Optional[str] = None class ChatResponse(BaseModel): @@ -280,22 +184,26 @@ class ChatMessageResponse(BaseModel): feedback: Optional[str] = None created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) class ChatHistoryResponse(BaseModel): messages: List[ChatMessageResponse] document_id: Optional[str] = None -# Chunk settings schema for optional chunk size and overlap parameters in document processing +# Chunk settings schema for optional chunk size and overlap +# parameters in document processing + + class ChunkSettings(BaseModel): chunk_size: int | None chunk_overlap: int | None - + + class UploadUrl(BaseModel): url: str + class ShareAnswerResponse(BaseModel): id: str content: str @@ -308,10 +216,6 @@ class ShareLinkResponse(BaseModel): share_url: str -class FeedbackRequest(BaseModel): - feedback: Optional[str] = None - - # ── Chat Session ────────────────────────────────────── class ChatSessionCreate(BaseModel): @@ -323,8 +227,7 @@ class ChatSessionResponse(BaseModel): title: str created_at: datetime - class Config: - from_attributes = True + model_config = ConfigDict(from_attributes=True) # Rebuild models for forward references diff --git a/frontend/src/components/chat/ChatPanel.tsx b/frontend/src/components/chat/ChatPanel.tsx index dde5d025..7b9194cc 100644 --- a/frontend/src/components/chat/ChatPanel.tsx +++ b/frontend/src/components/chat/ChatPanel.tsx @@ -1,17 +1,16 @@ "use client"; -import { toast } from "sonner"; import { useState, useRef, useEffect } from "react"; import { useTranslation } from "react-i18next"; import type { DocInfo } from "@/app/dashboard/page"; import { api, API_BASE } from "@/lib/api"; -import { useChatStore, type ChatMsg, type SourceBoundingBox, type SourceChunk } from "@/store/chat-store"; +import { useChatStore, type ChatMsg, type SourceChunk } from "@/store/chat-store"; import { Button } from "@/components/ui/button"; import { Skeleton } from "@/components/ui/skeleton"; import { Textarea } from "@/components/ui/textarea"; import MessageBubble from "./MessageBubble"; import SourceCard from "./SourceCard"; -import { Send, Loader2, Trash2, MessageSquare, Download, Mic, MicOff, HelpCircle } from "lucide-react"; +import { Send, Loader2, Trash2, MessageSquare, Download, Mic, MicOff, Settings2 } from "lucide-react"; import { cn } from "@/lib/utils"; interface ISpeechRecognitionEvent { @@ -48,14 +47,9 @@ interface WindowWithSpeech extends Window { webkitSpeechRecognition?: new () => ISpeechRecognition; } -interface CitationTarget { - page: number; - highlightRects?: SourceBoundingBox[]; -} - interface Props { activeDoc: DocInfo | null; - onCitationClick: (target: CitationTarget) => void; + onCitationClick: (page: number) => void; } export default function ChatPanel({ activeDoc, onCitationClick }: Props) { @@ -66,27 +60,25 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { const isTyping = useChatStore((state) => state.isTyping); const historyLoading = useChatStore((state) => state.historyLoading); const activeSessionId = useChatStore((state) => state.activeSessionId); + const topK = useChatStore((state) => state.topK); const setMessages = useChatStore((state) => state.setMessages); const setInput = useChatStore((state) => state.setInput); + const setTopK = useChatStore((state) => state.setTopK); const setStreaming = useChatStore((state) => state.setStreaming); const setIsTyping = useChatStore((state) => state.setIsTyping); const resetChat = useChatStore((state) => state.resetChat); const fetchSessionHistory = useChatStore((state) => state.fetchSessionHistory); - const [showExportMenu, setShowExportMenu] = useState(false); - const MAX_CHARACTERS = 2000; + const [showSettingsMenu, setShowSettingsMenu] = useState(false); const [isRecording, setIsRecording] = useState(false); const [speechError, setSpeechError] = useState(null); - - // New State for Keyboard Shortcuts Help Modal - const [showHelpModal, setShowHelpModal] = useState(false); - const recognitionRef = useRef(null); const initialInputRef = useRef(""); const textareaRef = useRef(null); const bottomRef = useRef(null); const prevDocId = useRef(null); const exportMenuRef = useRef(null); + const settingsMenuRef = useRef(null); const showEmptyState = messages.length === 0 && !isTyping && !historyLoading; @@ -179,6 +171,7 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { }; setMessages((prev) => [...prev, userMsg]); + const assistantId = `assistant-${Date.now()}`; let assistantCreated = false; @@ -186,144 +179,79 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { setIsTyping(true); try { - // Try WebSocket first for real-time agentic thought streaming - const token = typeof window !== "undefined" ? localStorage.getItem("token") : null; - const base = API_BASE || window.location.origin; - const wsScheme = base.startsWith("https") ? "wss" : base.startsWith("http") ? "ws" : "wss"; - const host = base.replace(/^https?:/, ""); - const wsUrl = `${wsScheme}:${host}/api/v1/chat/ws${token ? `?token=${encodeURIComponent(token)}` : ""}`; - - const ws = new WebSocket(wsUrl); - - const wsDone = new Promise((resolve, reject) => { - ws.onopen = () => { - // Send initial payload - ws.send(JSON.stringify({ question, document_id: activeDoc?.id || null, session_id: activeSessionId })); - }; - - // If WS doesn't open within 800ms, treat as failure and fallback - const connectTimeout = setTimeout(() => { - try { - ws.close(); - } catch (e) { - // ignore - } - reject(new Error("WebSocket connection timeout")); - }, 800); - - ws.onmessage = (ev) => { - clearTimeout(connectTimeout); - try { - const event = JSON.parse(ev.data); - if (event.type === "token") { - if (!assistantCreated) { - assistantCreated = true; - setIsTyping(false); - - const assistantMsg: ChatMsg = { - id: assistantId, - role: "assistant", - content: event.data as string, - sources: [], - isStreaming: true, - }; - - setMessages((prev) => [...prev, assistantMsg]); - } else { - setMessages((prev) => - prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + (event.data as string) } : m)) - ); - } - } else if (event.type === "sources") { - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, sources: event.data as SourceChunk[] } : m))); - } else if (event.type === "thought") { - // Append thoughts as a temporary assistant note (optional UI handling) - // For simplicity, add to assistant message content in brackets - setMessages((prev) => - prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + `\n[thought] ${event.data}` } : m)) - ); - } else if (event.type === "error") { - setIsTyping(false); - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, content: `Error: ${event.data}`, isStreaming: false } : m))); - ws.close(); - reject(new Error(String(event.data))); - } else if (event.type === "done") { - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, isStreaming: false } : m))); - ws.close(); - resolve(); - } - } catch (err) { - // ignore malformed messages - } - }; - - ws.onerror = (ev) => { - clearTimeout(connectTimeout); - reject(new Error("WebSocket error")); - }; - - ws.onclose = () => { - resolve(); - }; + const stream = api.streamPost("/api/v1/chat/ask/stream", { + question, + document_id: activeDoc?.id || null, + session_id: activeSessionId, + ...(topK ? { top_k: topK } : {}), }); - await wsDone; - } catch (err) { - // Fallback to existing SSE stream if WebSocket fails - try { - const stream = api.streamPost("/api/v1/chat/ask/stream", { - question, - document_id: activeDoc?.id || null, - session_id: activeSessionId, - }); - - for await (const event of stream) { - if (event.type === "token") { - if (!assistantCreated) { - assistantCreated = true; - setIsTyping(false); - - const assistantMsg: ChatMsg = { - id: assistantId, - role: "assistant", - content: event.data as string, - sources: [], - isStreaming: true, - }; - - setMessages((prev) => [...prev, assistantMsg]); - } else { - setMessages((prev) => - prev.map((m) => (m.id === assistantId ? { ...m, content: m.content + (event.data as string) } : m)) - ); - } - } else if (event.type === "sources") { - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, sources: event.data as SourceChunk[] } : m))); - } else if (event.type === "error") { + for await (const event of stream) { + if (event.type === "token") { + // Create assistant message only when first token arrives + if (!assistantCreated) { + assistantCreated = true; setIsTyping(false); - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, content: `Error: ${event.data}`, isStreaming: false } : m))); - } else if (event.type === "done") { - setMessages((prev) => prev.map((m) => (m.id === assistantId ? { ...m, isStreaming: false } : m))); + + const assistantMsg: ChatMsg = { + id: assistantId, + role: "assistant", + content: event.data as string, + sources: [], + isStreaming: true, + }; + + setMessages((prev) => [...prev, assistantMsg]); + } else { + setMessages((prev) => + prev.map((m) => + m.id === assistantId + ? { ...m, content: m.content + (event.data as string) } + : m + ) + ); } + } else if (event.type === "sources") { + setMessages((prev) => + prev.map((m) => + m.id === assistantId + ? { ...m, sources: event.data as SourceChunk[] } + : m + ) + ); + } else if (event.type === "error") { + setIsTyping(false); + setMessages((prev) => + prev.map((m) => + m.id === assistantId + ? { ...m, content: `Error: ${event.data}`, isStreaming: false } + : m + ) + ); + } else if (event.type === "done") { + setMessages((prev) => + prev.map((m) => + m.id === assistantId ? { ...m, isStreaming: false } : m + ) + ); } - } catch (err2) { - setIsTyping(false); - setMessages((prev) => - prev.map((m) => - m.id === assistantId - ? { - ...m, - content: t("chat.fallbackError", { - message: err2 instanceof Error ? err2.message : "Unknown error", - }), - isStreaming: false, - } - : m - ) - ); } + } catch (err) { + setIsTyping(false); + setMessages((prev) => + prev.map((m) => + m.id === assistantId + ? { + ...m, + content: t("chat.fallbackError", { + message: err instanceof Error ? err.message : "Unknown error", + }), + isStreaming: false, + } + : m + ) + ); } finally { - setStreaming(false); setIsTyping(false); } @@ -334,9 +262,8 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { try { await api.delete(`/api/v1/chat/history/${activeDoc.id}`); setMessages([]); - toast.info("Chat history cleared"); } catch { - // silent fail preserved; no additional toast for this scenario + //silent fail } }; @@ -354,17 +281,19 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { document.body.removeChild(a); }; - // Close export dropdown on outside click + // Close menus on outside click useEffect(() => { - if (!showExportMenu) return; const handleClickOutside = (e: MouseEvent) => { - if (exportMenuRef.current && !exportMenuRef.current.contains(e.target as Node)) { + if (showExportMenu && exportMenuRef.current && !exportMenuRef.current.contains(e.target as Node)) { setShowExportMenu(false); } + if (showSettingsMenu && settingsMenuRef.current && !settingsMenuRef.current.contains(e.target as Node)) { + setShowSettingsMenu(false); + } }; document.addEventListener("mousedown", handleClickOutside); return () => document.removeEventListener("mousedown", handleClickOutside); - }, [showExportMenu]); + }, [showExportMenu, showSettingsMenu]); // Cleanup speech recognition on unmount useEffect(() => { @@ -476,50 +405,15 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { } }; - const handleExportMenuKeyDown = (e: React.KeyboardEvent) => { + const handleMenuKeyDown = (e: React.KeyboardEvent, menuType: "export" | "settings") => { if (e.key === "Escape") { - setShowExportMenu(false); + if (menuType === "export") setShowExportMenu(false); + if (menuType === "settings") setShowSettingsMenu(false); } }; - // ── NEW KEYBOARD SHORTCUTS ENGINE EFFECT ────────────────────────── - useEffect(() => { - const handleGlobalKeyDown = (e: KeyboardEvent) => { - const isCmdOrCtrl = e.metaKey || e.ctrlKey; - - // Shortcut 1: Ctrl/Cmd + Enter -> Send Message (When textarea has focus) - if (isCmdOrCtrl && e.key === "Enter") { - if (document.activeElement === textareaRef.current) { - e.preventDefault(); - handleSend(); - } - } - - // Shortcut 2: Escape -> Clear Input / Close Modal - if (e.key === "Escape") { - if (document.activeElement === textareaRef.current) { - e.preventDefault(); - setInput(""); // Clear textarea state - } else if (showHelpModal) { - setShowHelpModal(false); // Close shortcuts modal if open - } - } - - // Shortcut 3: Ctrl/Cmd + K -> Focus chat input from anywhere - if (isCmdOrCtrl && (e.key === "k" || e.key === "K")) { - e.preventDefault(); - textareaRef.current?.focus(); - } - }; - - window.addEventListener("keydown", handleGlobalKeyDown); - return () => { - window.removeEventListener("keydown", handleGlobalKeyDown); - }; - }, [input, streaming, showHelpModal]); // Dependencies updated to capture fresh state data - return ( -
+
{/* ── Chat Messages ──────────────────────────── */}
{historyLoading ? ( @@ -619,7 +513,6 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) {