From 0ac2bcbe547e6e388029675dd98994ebbfc4115f Mon Sep 17 00:00:00 2001 From: hrshjswniii Date: Sun, 7 Jun 2026 19:29:19 +0530 Subject: [PATCH 1/6] [BUG] : Leaking Soft-Deleted Documents in Global Chat RAG Retrieval --- backend/app/rag/bm25.py | 6 ++++ backend/app/rag/retriever.py | 35 ++++++++++++++++++ backend/tests/test_retriever.py | 64 ++++++++++++++++++++++++++++++++- 3 files changed, 104 insertions(+), 1 deletion(-) diff --git a/backend/app/rag/bm25.py b/backend/app/rag/bm25.py index 82298bf5..de30a522 100644 --- a/backend/app/rag/bm25.py +++ b/backend/app/rag/bm25.py @@ -108,6 +108,7 @@ def query_bm25( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, top_k: int = 10, ) -> List[Dict[str, Any]]: """ @@ -127,6 +128,11 @@ def query_bm25( all_results = [] for path in glob.glob(os.path.join(user_dir, "*.pkl")): + # Filter by document_ids if provided + if document_ids is not None: + doc_id = os.path.basename(path).rsplit(".", 1)[0] + if doc_id not in document_ids: + continue results = _query_single_index(path, tokenized_query, top_k) all_results.extend(results) diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index e542c17f..b5b7d63f 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -41,6 +41,7 @@ def invoke(self, query): class CustomVectorRetriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -51,6 +52,7 @@ def _get_relevant_documents( query_embedding=query_vector, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] @@ -59,6 +61,7 @@ def _get_relevant_documents( class CustomBM25Retriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -69,11 +72,13 @@ def _get_relevant_documents( query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] + def transform_query(query: str) -> List[str]: """Rewrite a user question into multiple retrieval-friendly search queries.""" original_query = query.strip() @@ -228,17 +233,47 @@ def retrieve( Returns chunks with confidence scores. """ + from app.database import SessionLocal + from app.models import Document + + # Fetch active document IDs + db = SessionLocal() + try: + if document_id: + # Check if specific document is active (not deleted) + doc = db.query(Document).filter( + Document.id == document_id, + Document.user_id == user_id, + Document.is_deleted.is_(False), + ).first() + if not doc: + return [] + active_doc_ids = [str(doc.id)] + else: + # Check all active documents for this user + docs = db.query(Document).filter( + Document.user_id == user_id, + Document.is_deleted.is_(False), + ).all() + if not docs: + return [] + active_doc_ids = [str(doc.id) for doc in docs] + finally: + db.close() + # ── Stage 1: Hybrid Search with Query Transformation ───────────── effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL vector_retriever = CustomVectorRetriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) bm25_retriever = CustomBM25Retriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) diff --git a/backend/tests/test_retriever.py b/backend/tests/test_retriever.py index 6045dde4..66e29083 100644 --- a/backend/tests/test_retriever.py +++ b/backend/tests/test_retriever.py @@ -30,7 +30,25 @@ def test_retrieve_fans_out_transformed_queries_and_merges_duplicates(monkeypatch monkeypatch.setattr(retriever, "embed_query", lambda query: f"embedding:{query}") monkeypatch.setattr(retriever, "get_reranker", lambda: None) - def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): + # Mock SessionLocal and Document + class MockDoc: + id = "policy.pdf" + + class MockQuery: + def filter(self, *args, **kwargs): + return self + def all(self): + return [MockDoc()] + + class MockSession: + def query(self, *args, **kwargs): + return MockQuery() + def close(self): + pass + + monkeypatch.setattr("app.database.SessionLocal", lambda: MockSession()) + + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): searched_queries.append(query_embedding) if query_embedding == "embedding:taxes": return [ @@ -75,3 +93,47 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"] assert chunks[0]["score"] == 1.0 assert chunks[0]["confidence"] == 100.0 + + +def test_retrieve_excludes_soft_deleted_documents(db_session, user, monkeypatch): + from app.models import Document + from app.rag import retriever + + # Create one active document and one deleted document + active_doc = Document( + id="active-doc-id", + user_id=user.id, + filename="active.pdf", + original_name="active.pdf", + is_deleted=False, + ) + deleted_doc = Document( + id="deleted-doc-id", + user_id=user.id, + filename="deleted.pdf", + original_name="deleted.pdf", + is_deleted=True, + ) + db_session.add(active_doc) + db_session.add(deleted_doc) + db_session.commit() + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + monkeypatch.setattr(retriever, "transform_query", lambda _query: ["query"]) + monkeypatch.setattr(retriever, "embed_query", lambda query: "embedding") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + captured_doc_ids = [] + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + nonlocal captured_doc_ids + captured_doc_ids = document_ids + return [] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr(retriever.CustomBM25Retriever, "_get_relevant_documents", lambda *args, **kwargs: []) + + retriever.retrieve("test query", user_id=user.id) + + # Should only query for the active document ID + assert captured_doc_ids == ["active-doc-id"] + From 2328fc5ea867a1196096e4db42b7adbe13d50c43 Mon Sep 17 00:00:00 2001 From: hrshjswniii Date: Sun, 7 Jun 2026 20:02:28 +0530 Subject: [PATCH 2/6] feat: implement collaborative workspaces and fix unit test issues --- backend/app/database.py | 1 + backend/app/models.py | 33 +++ backend/app/rag/agent.py | 9 +- backend/app/rag/retriever.py | 48 +++- backend/app/rag/tools.py | 2 + backend/app/routes/chat.py | 10 +- backend/app/routes/documents.py | 140 +++++++++--- backend/app/routes/workspaces.py | 173 +++++++++++++- backend/app/schemas.py | 18 ++ backend/tests/test_rag_tools.py | 2 +- backend/tests/test_workspaces.py | 154 ++++++++++--- frontend/src/app/dashboard/page.tsx | 21 +- frontend/src/app/invite/page.tsx | 215 ++++++++++++++++++ frontend/src/components/chat/ChatPanel.tsx | 3 + .../components/document/DocumentSidebar.tsx | 5 +- 15 files changed, 737 insertions(+), 97 deletions(-) create mode 100644 frontend/src/app/invite/page.tsx diff --git a/backend/app/database.py b/backend/app/database.py index b7fa2bd2..1ffea5ae 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -113,6 +113,7 @@ def _migrate_schema(): ("documents", "drive_file_id", "ALTER TABLE documents ADD COLUMN drive_file_id VARCHAR(255)"), ("documents", "drive_folder_id", "ALTER TABLE documents ADD COLUMN drive_folder_id VARCHAR(255)"), ("documents", "drive_synced_at", "ALTER TABLE documents ADD COLUMN drive_synced_at TIMESTAMP"), + ("documents", "workspace_id", "ALTER TABLE documents ADD COLUMN workspace_id VARCHAR(36)"), ] for table, column, ddl in docs_migrations: if column not in existing_docs_columns: diff --git a/backend/app/models.py b/backend/app/models.py index 25587fc0..0b134edb 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -163,6 +163,11 @@ class User(Base): back_populates="user", cascade="all, delete-orphan", ) + workspace_memberships = relationship( + "WorkspaceMembership", + back_populates="user", + cascade="all, delete-orphan", + ) class ApiKey(Base): @@ -185,6 +190,32 @@ class ApiKey(Base): user = relationship("User", back_populates="api_keys") +class Workspace(Base): + __tablename__ = "workspaces" + + id = Column(String(36), primary_key=True, default=generate_uuid) + name = Column(String(100), nullable=False) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + memberships = relationship("WorkspaceMembership", back_populates="workspace", cascade="all, delete-orphan") + documents = relationship("Document", back_populates="workspace") + + +class WorkspaceMembership(Base): + __tablename__ = "workspace_memberships" + + id = Column(String(36), primary_key=True, default=generate_uuid) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=False, index=True) + user_id = Column(GUID, ForeignKey("users.id"), nullable=False, index=True) + role = Column(String(20), default="member", nullable=False) # "admin" | "member" + joined_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + workspace = relationship("Workspace", back_populates="memberships") + user = relationship("User", back_populates="workspace_memberships") + + class WorkspaceInvitation(Base): __tablename__ = "workspace_invitations" @@ -254,9 +285,11 @@ class Document(Base): drive_synced_at = Column(DateTime, nullable=True) is_deleted = Column(Boolean, default=False, nullable=False, index=True) deleted_at = Column(DateTime, nullable=True) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=True, index=True) # Relationships owner = relationship("User", back_populates="documents") + workspace = relationship("Workspace", back_populates="documents") messages = relationship( "ChatMessage", back_populates="document", diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index ceae0a7e..64a90e3b 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -55,11 +55,12 @@ def get_agent_executor( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ): """Initialize the LangChain ReAct agent executor.""" # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) + pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, workspace=workspace, top_k=top_k) tools = [pdf_tool, MathTool(), WebSearchTool()] # Initialize LLM @@ -121,6 +122,7 @@ def generate_answer( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Dict[str, Any]: """ Agentic generation: retrieve via tools → reason → generate answer. @@ -145,7 +147,7 @@ def generate_answer( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) result = executor.invoke({"input": question, "chat_history": formatted_history}) raw_answer = result.get("output", "") @@ -193,6 +195,7 @@ def generate_answer_stream( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Generator[str, None, None]: """ Streaming Agentic pipeline. @@ -218,7 +221,7 @@ def generate_answer_stream( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) sources_sent = False diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index b5b7d63f..65d96700 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -224,6 +224,7 @@ def retrieve( query: str, user_id: str, document_id: Optional[str] = None, + workspace: Optional[str] = None, top_k: Optional[int] = None, ) -> List[Dict[str, Any]]: """ @@ -234,7 +235,7 @@ def retrieve( Returns chunks with confidence scores. """ from app.database import SessionLocal - from app.models import Document + from app.models import Document, WorkspaceMembership # Fetch active document IDs db = SessionLocal() @@ -243,18 +244,51 @@ def retrieve( # Check if specific document is active (not deleted) doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user_id, Document.is_deleted.is_(False), ).first() if not doc: return [] + + # Verify user has access to the document + has_access = False + if str(doc.user_id) == str(user_id): + has_access = True + elif doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + has_access = True + + if not has_access: + return [] active_doc_ids = [str(doc.id)] else: - # Check all active documents for this user - docs = db.query(Document).filter( - Document.user_id == user_id, - Document.is_deleted.is_(False), - ).all() + # Filter documents by workspace + filters = [Document.is_deleted.is_(False)] + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return [] + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + filters.append(Document.user_id == user_id) + filters.append(Document.workspace_id.is_(None)) + else: + # Specific workspace ID + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user_id, + ).first() + if not membership: + return [] + filters.append(Document.workspace_id == workspace) + + docs = db.query(Document).filter(*filters).all() if not docs: return [] active_doc_ids = [str(doc.id) for doc in docs] diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 03813756..c03c56ab 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool): user_id: str document_id: Optional[str] = None + workspace: Optional[str] = None top_k: Optional[int] = None # We'll store sources here to retrieve them after agent execution last_sources: List[Dict[str, Any]] = [] @@ -167,6 +168,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + workspace=self.workspace, top_k=self.top_k, ) diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 8bfd8c61..09fa7a41 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -262,16 +262,16 @@ def get_session_history( return ChatHistoryResponse(messages=formatted, document_id=None) -def generate_answer(question: str, user_id: str, document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None): +def generate_answer(question: str, user_id: str, document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, workspace: Optional[str] = None): from app.rag.agent import generate_answer as _generate_answer - return _generate_answer(question=question, user_id=user_id, document_id=document_id, hf_token=hf_token, top_k=top_k, chat_history=chat_history) + return _generate_answer(question=question, user_id=user_id, document_id=document_id, hf_token=hf_token, top_k=top_k, chat_history=chat_history, workspace=workspace) -def generate_answer_stream(question: str, user_id: str, document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None): +def generate_answer_stream(question: str, user_id: str, document_id: Optional[str] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, workspace: Optional[str] = None): from app.rag.agent import generate_answer_stream as _generate_answer_stream - return _generate_answer_stream(question=question, user_id=user_id, document_id=document_id, hf_token=hf_token, top_k=top_k, chat_history=chat_history) + return _generate_answer_stream(question=question, user_id=user_id, document_id=document_id, hf_token=hf_token, top_k=top_k, chat_history=chat_history, workspace=workspace) @router.post( @@ -351,6 +351,7 @@ def ask_question( hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ) # Save to chat history @@ -451,6 +452,7 @@ def event_stream(): hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ): yield chunk diff --git a/backend/app/routes/documents.py b/backend/app/routes/documents.py index 2c7d8fac..ad5df028 100644 --- a/backend/app/routes/documents.py +++ b/backend/app/routes/documents.py @@ -14,12 +14,12 @@ import shutil import tempfile from urllib.parse import urlparse -from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status, Query, BackgroundTasks +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status, Query, BackgroundTasks, Form from fastapi.responses import FileResponse from sqlalchemy.orm import Session from app.database import get_db -from app.models import User, Document +from app.models import User, Document, Workspace, WorkspaceMembership from app.schemas import ( DocumentResponse, DocumentListResponse, @@ -51,6 +51,50 @@ router = APIRouter(prefix="/documents", tags=["Documents"]) +def get_target_workspace_id(workspace_param: Optional[str], user_id: str, db: Session) -> Optional[str]: + if not workspace_param or workspace_param == "personal": + return None + + if workspace_param == "company": + # Check if the user is a member of any workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).first() + if membership: + return membership.workspace_id + + # If not, create a default "Company" workspace for them + workspace = Workspace(name="Company") + db.add(workspace) + db.commit() + db.refresh(workspace) + + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=user_id, + role="admin", + ) + db.add(membership) + db.commit() + return workspace.id + + # If a specific workspace ID was passed, just return it + return workspace_param + + +def check_document_access(doc: Document, user_id: str, db: Session) -> bool: + if str(doc.user_id) == str(user_id): + return True + if doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + return True + return False + + ALLOWED_MIME_TYPES = settings.ALLOWED_MIME_TYPES @@ -169,6 +213,7 @@ async def _crawl(): @router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_202_ACCEPTED) async def upload_document( file: UploadFile = File(...), + workspace: Optional[str] = Form(None), background_tasks: BackgroundTasks = None, user: User = Depends(get_current_user), db: Session = Depends(get_db), @@ -184,10 +229,11 @@ async def upload_document( Args: file: The uploaded file, provided as a multipart/form-data field in the request. + workspace: Optional workspace context ('personal', 'company', or specific UUID). background_tasks: FastAPI BackgroundTasks instance for in-process fallback execution. user: The currently authenticated user, injected by the `get_current_user` dependency. db: Database session, injected by the `get_db` dependency. - + Returns: DocumentResponse: The created document record, validated against the response model (includes id, filename, original_name, file_size, status, etc.). @@ -225,6 +271,11 @@ async def upload_document( file_size = Path(filepath).stat().st_size + # Resolve target workspace + if not isinstance(workspace, str): + workspace = None + target_workspace_id = get_target_workspace_id(workspace, user.id, db) + # ── Create database record ─────────────────────── document = Document( user_id=user.id, @@ -232,6 +283,7 @@ async def upload_document( original_name=file.filename, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -324,6 +376,9 @@ async def upload_document_url( url_path = parsed.path.rstrip("/") original_name = f"{parsed.netloc}{url_path or ''}.txt" + # Resolve target workspace + target_workspace_id = get_target_workspace_id(payload.workspace, user.id, db) + # ── Create database record ───────────────────────────── document = Document( user_id=user.id, @@ -331,6 +386,7 @@ async def upload_document_url( original_name=original_name, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -390,11 +446,10 @@ def get_document_status( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise HTTPException(status_code=404, detail="Document not found") return DocumentStatusResponse.model_validate(doc) @@ -404,36 +459,53 @@ def get_document_status( def list_documents( page: int = Query(1, ge=1), per_page: int = Query(20, ge=1), + workspace: Optional[str] = Query(None), user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """ - List all documents for the authenticated user with pagination. + List all documents for the authenticated user with pagination, filtered by workspace. - Returns a paginated list of documents belonging to the current user, + Returns a paginated list of documents belonging to the current user or their workspaces, ordered by upload date (newest first). - - Args: - page: The page number to retrieve (1: indexed). Defaults to 1. - per_page: The number of documents to return per page. Defaults to 20. - user: The currently authenticated user, injected by the `get_current_user` dependency. - db: Database session, injected by the `get_db` dependency. - - Returns: - DocumentListResponse: A response model containing: - - items: A list of DocumentResponse objects for the current page. - - total: The total number of documents for the user. - - page: The current page number. - - pages: The total number of pages available. """ """Number of rows to skip""" skip: int = (page - 1) * per_page + # Base query filters + filters = [Document.is_deleted.is_(False)] + + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user.id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return DocumentListResponse(items=[], total=0, page=page, pages=0) + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + # Default to personal: own documents with no workspace + filters.append(Document.user_id == user.id) + filters.append(Document.workspace_id.is_(None)) + else: + # Filter by a specific workspace ID + # Verify user has access to this workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user.id, + ).first() + if not membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this workspace.", + ) + filters.append(Document.workspace_id == workspace) + """Total Pages""" totalDocuments = ( db.query(Document) - .filter(Document.user_id == user.id, Document.is_deleted.is_(False)) + .filter(*filters) .count() ) """Total Pages""" @@ -442,7 +514,7 @@ def list_documents( """List all documents for the authenticated user in Paginated form""" docs = (( db.execute(select(Document) - .where(Document.user_id == user.id, Document.is_deleted.is_(False)) + .where(*filters) .order_by(Document.uploaded_at.desc()) .limit(per_page).offset(skip)) ) @@ -510,11 +582,10 @@ def get_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise HTTPException(status_code=404, detail="Document not found") return DocumentResponse.model_validate(doc) @@ -547,14 +618,13 @@ def serve_pdf( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise HTTPException(status_code=404, detail="Document not found") - filepath = os.path.join(settings.UPLOAD_DIR, user.id, doc.filename) + filepath = os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename) if not os.path.exists(filepath): raise HTTPException(status_code=404, detail="File not found on disk") @@ -597,11 +667,10 @@ def delete_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise HTTPException(status_code=404, detail="Document not found") doc.is_deleted = True @@ -637,14 +706,13 @@ def update_chunk_settings( HTTPException: With status code 404 if the document is not found or does not belong to the authenticated user. HTTPException: With status code 400 if the provided chunk size or overlap values are invalid (e.g., chunk size less than 100, or overlap greater than or equal to chunk size). """ - # Validate if the document exists and belongs to the user + # Validate if the document exists and belongs to the user or workspace doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise HTTPException(status_code=404, detail="Document not found") if settings_update.chunk_size is not None: @@ -674,9 +742,9 @@ def update_chunk_settings( try: task = process_document.delay( document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = task.id except Exception as e: @@ -685,9 +753,9 @@ def update_chunk_settings( background_tasks.add_task( ingest_document, document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = f"local_{uuid.uuid4().hex}" diff --git a/backend/app/routes/workspaces.py b/backend/app/routes/workspaces.py index 7ce2de63..8157389a 100644 --- a/backend/app/routes/workspaces.py +++ b/backend/app/routes/workspaces.py @@ -8,12 +8,17 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.auth import create_invite_token, get_admin_user +from app.auth import create_invite_token, get_current_user, decode_invite_token from app.config import get_settings from app.database import get_db from app.email_service import send_email -from app.models import User, WorkspaceInvitation -from app.schemas import WorkspaceInviteRequest, WorkspaceInviteResponse +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation +from app.schemas import ( + WorkspaceInviteRequest, + WorkspaceInviteResponse, + WorkspaceInviteVerifyResponse, + WorkspaceInviteAcceptResponse, +) router = APIRouter(prefix="/workspaces", tags=["Workspaces"]) settings = get_settings() @@ -23,24 +28,59 @@ @router.post("/invite", response_model=WorkspaceInviteResponse, status_code=status.HTTP_200_OK) def invite_workspace( payload: WorkspaceInviteRequest, - admin_user: User = Depends(get_admin_user), + current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """Invite a user by email to join a workspace via a secure time-bound token.""" - existing_user = db.query(User).filter(User.email == payload.email).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="A user with this email already exists.", + # 1. Lookup or create the Workspace + workspace = db.query(Workspace).filter(Workspace.name == payload.workspace_name).first() + if not workspace: + workspace = Workspace(name=payload.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Add the inviter as the admin of this workspace + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="admin", ) + db.add(membership) + db.commit() + else: + # Verify if current_user is a member of this workspace + user_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + if not user_membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to invite users to this workspace.", + ) + + # 2. Check if the invited user is already a member of this workspace + invited_user = db.query(User).filter(User.email == payload.email).first() + if invited_user: + existing_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == invited_user.id, + ).first() + if existing_membership: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The user is already a member of this workspace.", + ) - token = create_invite_token(admin_user.id, payload.email, payload.workspace_name) + # 3. Create invitation + token = create_invite_token(current_user.id, payload.email, payload.workspace_name) token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.INVITE_TOKEN_EXPIRY_HOURS) invitation = WorkspaceInvitation( email=payload.email, - inviter_id=admin_user.id, + inviter_id=current_user.id, token_hash=token_hash, workspace_name=payload.workspace_name, expires_at=expires_at, @@ -71,3 +111,114 @@ def invite_workspace( invite_link=join_link, expires_in_hours=settings.INVITE_TOKEN_EXPIRY_HOURS, ) + + +@router.get("/invite/verify", response_model=WorkspaceInviteVerifyResponse, status_code=status.HTTP_200_OK) +def verify_workspace_invite( + token: str, + db: Session = Depends(get_db), +): + """Verify an invitation token and return workspace details.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + inviter = db.query(User).filter(User.id == invitation.inviter_id).first() + inviter_email = inviter.email if inviter else "unknown" + inviter_username = inviter.username if inviter else "unknown" + + is_expired = invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc) + + return WorkspaceInviteVerifyResponse( + workspace_name=invitation.workspace_name, + inviter_email=inviter_email, + inviter_username=inviter_username, + email=invitation.email, + expires_at=invitation.expires_at, + is_expired=is_expired, + is_accepted=invitation.accepted_at is not None, + ) + + +@router.post("/invite/accept", response_model=WorkspaceInviteAcceptResponse, status_code=status.HTTP_200_OK) +def accept_workspace_invite( + token: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Accept a workspace invitation using the time-bound token.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + if invitation.accepted_at is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has already been accepted.", + ) + + if invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has expired.", + ) + + # Lookup or create the workspace + workspace = db.query(Workspace).filter( + Workspace.name == invitation.workspace_name + ).first() + if not workspace: + workspace = Workspace(name=invitation.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Check if the user is already a member + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + + if not membership: + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="member", + ) + db.add(membership) + + invitation.accepted_at = datetime.now(timezone.utc) + db.commit() + + return WorkspaceInviteAcceptResponse( + message="Invitation accepted successfully.", + workspace_name=workspace.name, + ) diff --git a/backend/app/schemas.py b/backend/app/schemas.py index f6c0c752..7c592aa3 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -86,6 +86,21 @@ class WorkspaceInviteResponse(BaseModel): expires_in_hours: int +class WorkspaceInviteVerifyResponse(BaseModel): + workspace_name: str + inviter_email: str + inviter_username: str + email: str + expires_at: datetime + is_expired: bool + is_accepted: bool + + +class WorkspaceInviteAcceptResponse(BaseModel): + message: str + workspace_name: str + + class TokenResponse(BaseModel): access_token: str refresh_token: str @@ -160,6 +175,7 @@ class DocumentResponse(BaseModel): uploaded_at: datetime summary: Optional[str] = None # New field for document summary task_id: Optional[str] = None + workspace_id: Optional[str] = None class Config: from_attributes = True @@ -224,6 +240,7 @@ class ChatRequest(BaseModel): document_ids: Optional[List[str]] = None session_id: Optional[str] = None top_k: int = Field(default=5, ge=1, le=20) + workspace: Optional[str] = None class SourceChunk(BaseModel): @@ -268,6 +285,7 @@ class ChunkSettings(BaseModel): class UploadUrl(BaseModel): url: str + workspace: Optional[str] = None class ShareAnswerResponse(BaseModel): id: str diff --git a/backend/tests/test_rag_tools.py b/backend/tests/test_rag_tools.py index 30bbc9fa..12eddb81 100644 --- a/backend/tests/test_rag_tools.py +++ b/backend/tests/test_rag_tools.py @@ -154,7 +154,7 @@ def test_pdf_search_tool_formats_chunks_and_graph_context(monkeypatch): retrieve_calls = [] graph_calls = [] - def fake_retrieve(query, user_id, document_id=None, top_k=None): + def fake_retrieve(query, user_id, document_id=None, workspace=None, top_k=None, **kwargs): retrieve_calls.append((query, user_id, document_id)) return chunks diff --git a/backend/tests/test_workspaces.py b/backend/tests/test_workspaces.py index 8bf5610a..01ae03f0 100644 --- a/backend/tests/test_workspaces.py +++ b/backend/tests/test_workspaces.py @@ -1,8 +1,20 @@ -from app.auth import create_access_token, hash_password -from app.models import User, WorkspaceInvitation +import hashlib +from datetime import datetime, timedelta, timezone +from app.auth import create_access_token, hash_password, create_invite_token +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation + + +def test_workspace_invite_creates_workspace_and_membership_for_user(client, db_session, user, monkeypatch): + sent = {} + + def fake_send_email(to, subject, body, html=None): + sent["to"] = to + sent["subject"] = subject + sent["body"] = body + + monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) -def test_workspace_invite_requires_admin(client, db_session, user): token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", @@ -10,47 +22,125 @@ def test_workspace_invite_requires_admin(client, db_session, user): json={"email": "invitee@example.com", "workspace_name": "Engineering"}, ) - assert response.status_code == 403 - assert response.json()["detail"] == "Admin access required" + assert response.status_code == 200 + payload = response.json() + assert payload["email"] == "invitee@example.com" + assert payload["workspace_name"] == "Engineering" + assert "invite_link" in payload + + # Verify workspace and membership were created + workspace = db_session.query(Workspace).filter_by(name="Engineering").first() + assert workspace is not None + + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=user.id + ).first() + assert membership is not None + assert membership.role == "admin" -def test_workspace_invite_creates_invitation_and_sends_email(client, db_session, monkeypatch): - admin = User( - username="admin", - email="admin@example.com", +def test_workspace_invite_existing_member_fails(client, db_session, user, monkeypatch): + # Setup workspace and membership + workspace = Workspace(name="Marketing") + db_session.add(workspace) + db_session.commit() + + member = User( + username="member", + email="member@example.com", hashed_password=hash_password("password123"), - is_admin=True, ) - db_session.add(admin) + db_session.add(member) db_session.commit() - db_session.refresh(admin) - sent = {} - - def fake_send_email(to, subject, body, html=None): - sent["to"] = to - sent["subject"] = subject - sent["body"] = body - - monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) + membership1 = WorkspaceMembership(workspace_id=workspace.id, user_id=user.id, role="admin") + membership2 = WorkspaceMembership(workspace_id=workspace.id, user_id=member.id, role="member") + db_session.add(membership1) + db_session.add(membership2) + db_session.commit() - token = create_access_token(admin.id) + token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", headers={"Authorization": f"Bearer {token}"}, - json={"email": "invitee@example.com", "workspace_name": "Engineering"}, + json={"email": "member@example.com", "workspace_name": "Marketing"}, + ) + + assert response.status_code == 400 + assert "already a member" in response.json()["detail"] + + +def test_workspace_invite_verify(client, db_session, user): + invite_token = create_invite_token(user.id, "invitee@example.com", "Sales") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Sales", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + response = client.get( + f"/api/v1/workspaces/invite/verify?token={invite_token}" ) assert response.status_code == 200 payload = response.json() + assert payload["workspace_name"] == "Sales" assert payload["email"] == "invitee@example.com" - assert payload["workspace_name"] == "Engineering" - assert "invite_link" in payload - assert payload["invite_link"].startswith("http") - assert "token=" in payload["invite_link"] - assert sent["to"] == "invitee@example.com" - assert "Invitation to join workspace" in sent["subject"] - - invitation = db_session.query(WorkspaceInvitation).filter_by(email="invitee@example.com").first() - assert invitation is not None - assert invitation.workspace_name == "Engineering" + assert payload["inviter_email"] == user.email + assert payload["is_expired"] is False + assert payload["is_accepted"] is False + + +def test_workspace_invite_accept(client, db_session, user): + invitee = User( + username="invitee", + email="invitee@example.com", + hashed_password=hash_password("password123"), + ) + db_session.add(invitee) + db_session.commit() + + invite_token = create_invite_token(user.id, "invitee@example.com", "Support") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Support", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + # Pre-create workspace to map + workspace = Workspace(name="Support") + db_session.add(workspace) + db_session.commit() + + token = create_access_token(invitee.id) + response = client.post( + f"/api/v1/workspaces/invite/accept?token={invite_token}", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200 + assert response.json()["workspace_name"] == "Support" + + # Verify membership and invitation status + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=invitee.id + ).first() + assert membership is not None + assert membership.role == "member" + + db_session.refresh(invitation) + assert invitation.accepted_at is not None diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index f5a25f0a..3384e242 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -10,6 +10,7 @@ import Header from "@/components/layout/Header"; import DocumentSidebar from "@/components/document/DocumentSidebar"; import ChatSessionSidebar from "@/components/chat/ChatSessionSidebar"; import ChatPanel from "@/components/chat/ChatPanel"; +import { useWorkspaceStore } from "@/store/workspace-store"; function PDFViewerSkeleton() { return (
s.workspace); const [documents, setDocuments] = useState([]); const prevDocsRef = useRef>({}); @@ -91,6 +93,21 @@ export default function DashboardPage() { if (initialized && !user) router.replace("/login"); }, [user, initialized, router]); + // Handle pending workspace invitations after login/register + useEffect(() => { + if (initialized && user) { + try { + const pendingToken = sessionStorage.getItem("pending_invite_token"); + if (pendingToken) { + sessionStorage.removeItem("pending_invite_token"); + router.replace(`/invite?token=${encodeURIComponent(pendingToken)}`); + } + } catch (e) { + console.warn("sessionStorage not accessible", e); + } + } + }, [user, initialized, router]); + // Check if Hugging Face token configuration is present useEffect(() => { if (user) { @@ -110,7 +127,7 @@ export default function DashboardPage() { setDocumentsLoading(true); try { const data = await api.get<{ documents?: DocInfo[]; items?: DocInfo[] }>( - "/api/v1/documents/" + `/api/v1/documents/?workspace=${encodeURIComponent(workspace)}` ); setDocuments(data?.documents ?? data?.items ?? []); setConnectionError(""); @@ -124,7 +141,7 @@ export default function DashboardPage() { } finally { setDocumentsLoading(false); } - }, []); + }, [workspace]); useEffect(() => { if (!user) return; diff --git a/frontend/src/app/invite/page.tsx b/frontend/src/app/invite/page.tsx new file mode 100644 index 00000000..0d7a5ce1 --- /dev/null +++ b/frontend/src/app/invite/page.tsx @@ -0,0 +1,215 @@ +"use client"; + +import { useEffect, useState, Suspense } from "react"; +import { useRouter, useSearchParams } from "next/navigation"; +import { api } from "@/lib/api"; +import { useAuth } from "@/lib/auth"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Brain, CheckCircle2, AlertTriangle, Loader2, ArrowRight } from "lucide-react"; +import { toast } from "sonner"; +import { useWorkspaceStore } from "@/store/workspace-store"; + +interface InviteInfo { + workspace_name: string; + inviter_email: string; + inviter_username: string; + email: string; + expires_at: string; + is_expired: boolean; + is_accepted: boolean; +} + +function InviteContent() { + const router = useRouter(); + const searchParams = useSearchParams(); + const { user, initialized } = useAuth(); + const setWorkspace = useWorkspaceStore((s) => s.setWorkspace); + + const token = searchParams.get("token"); + + const [loading, setLoading] = useState(true); + const [accepting, setAccepting] = useState(false); + const [error, setError] = useState(""); + const [inviteInfo, setInviteInfo] = useState(null); + + useEffect(() => { + if (!token) { + setError("Invitation token is missing. Please check the link in your email."); + setLoading(false); + return; + } + + // Save token to sessionStorage in case they need to log in/register + try { + sessionStorage.setItem("pending_invite_token", token); + } catch (e) { + console.warn("sessionStorage is not available", e); + } + + api + .get(`/api/v1/workspaces/invite/verify?token=${encodeURIComponent(token)}`) + .then((data) => { + setInviteInfo(data); + if (data.is_expired) { + setError(`This invitation has expired (expired at ${new Date(data.expires_at).toLocaleString()}).`); + } else if (data.is_accepted) { + setError("This invitation has already been accepted."); + } + }) + .catch((err) => { + setError(err instanceof Error ? err.message : "Failed to verify invitation token."); + }) + .finally(() => { + setLoading(false); + }); + }, [token]); + + const handleAccept = async () => { + if (!token) return; + setAccepting(true); + + try { + await api.post(`/api/v1/workspaces/invite/accept?token=${encodeURIComponent(token)}`); + toast.success(`🎉 Welcome to the '${inviteInfo?.workspace_name}' workspace!`); + + // Clean up sessionStorage + try { + sessionStorage.removeItem("pending_invite_token"); + } catch (e) { + // ignore + } + + // Switch to company workspace in store so they see the documents immediately + setWorkspace("company"); + + router.push("/dashboard"); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to accept invitation."); + setAccepting(false); + } + }; + + if (loading || !initialized) { + return ( +
+ +

Verifying invitation details...

+
+ ); + } + + if (error) { + return ( + + +
+
+ +
+
+ Invitation Error + {error} +
+ + + +
+ ); + } + + const isLoggedIn = !!user; + + return ( + + +
+
+ +
+
+ Workspace Invitation + + You've been invited to join a collaborative workspace + +
+ + +
+

Workspace

+

{inviteInfo?.workspace_name}

+

+ Invited by {inviteInfo?.inviter_username} ({inviteInfo?.inviter_email}) +

+
+ + {isLoggedIn ? ( +
+

+ You are logged in as {user.username} ({user.email}). +

+ +
+ ) : ( +
+
+ Please log in or register an account to accept this invitation. +
+
+ + +
+
+ )} +
+
+ ); +} + +export default function InvitePage() { + return ( +
+ {/* Aesthetic blur gradients */} +
+
+ + + +

Loading...

+
+ }> + + +
+ ); +} diff --git a/frontend/src/components/chat/ChatPanel.tsx b/frontend/src/components/chat/ChatPanel.tsx index 25aceacc..119effc0 100644 --- a/frontend/src/components/chat/ChatPanel.tsx +++ b/frontend/src/components/chat/ChatPanel.tsx @@ -13,6 +13,7 @@ import MessageBubble from "./MessageBubble"; import SourceCard from "./SourceCard"; import { Send, Loader2, Trash2, MessageSquare, Download, Mic, MicOff, HelpCircle } from "lucide-react"; import { cn } from "@/lib/utils"; +import { useWorkspaceStore } from "@/store/workspace-store"; interface ISpeechRecognitionEvent { resultIndex: number; @@ -60,6 +61,7 @@ interface Props { export default function ChatPanel({ activeDoc, onCitationClick }: Props) { const { t, i18n } = useTranslation(); + const workspace = useWorkspaceStore((s) => s.workspace); const messages = useChatStore((state) => state.messages); const input = useChatStore((state) => state.input); const streaming = useChatStore((state) => state.streaming); @@ -190,6 +192,7 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { question, document_id: activeDoc?.id || null, session_id: activeSessionId, + workspace, }); for await (const event of stream) { diff --git a/frontend/src/components/document/DocumentSidebar.tsx b/frontend/src/components/document/DocumentSidebar.tsx index a4c158e5..d1079b12 100644 --- a/frontend/src/components/document/DocumentSidebar.tsx +++ b/frontend/src/components/document/DocumentSidebar.tsx @@ -17,6 +17,7 @@ import { useDropzone } from "react-dropzone"; import { Settings } from "lucide-react"; import DocumentSettings from "./DocumentSettings"; import { toast } from "sonner"; +import { useWorkspaceStore } from "@/store/workspace-store"; interface Props { documents: DocInfo[]; @@ -58,6 +59,7 @@ export default function DocumentSidebar({ onDocumentRenamed, }: Props) { const { t } = useTranslation(); + const workspace = useWorkspaceStore((s) => s.workspace); const [uploading, setUploading] = useState(false); const [uploadProgress, setUploadProgress] = useState(0); const [uploadError, setUploadError] = useState(""); @@ -105,6 +107,7 @@ export default function DocumentSidebar({ const file = acceptedFiles[i]; const formData = new FormData(); formData.append("file", file); + formData.append("workspace", workspace); toast.info(`⏳ Uploading '${file.name}'...`); await api.postForm("/api/v1/documents/upload", formData); @@ -122,7 +125,7 @@ export default function DocumentSidebar({ } })(); }, - [onDocumentsChange, t] + [onDocumentsChange, t, workspace] ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ From 9b8332a9dc371574fff02a93ae1d154cb8135805 Mon Sep 17 00:00:00 2001 From: hrshjswniii Date: Sun, 7 Jun 2026 20:07:36 +0530 Subject: [PATCH 3/6] [BUGFIX] : Fix Unimplemented Collaborative Workspaces --- PDF-Assistant-RAG | 1 + 1 file changed, 1 insertion(+) create mode 160000 PDF-Assistant-RAG diff --git a/PDF-Assistant-RAG b/PDF-Assistant-RAG new file mode 160000 index 00000000..ff801edb --- /dev/null +++ b/PDF-Assistant-RAG @@ -0,0 +1 @@ +Subproject commit ff801edbf71981bda028bd49f695da7122ecc936 From bfb7c689eabb6c2fe2988beb7fcc61c32974ee25 Mon Sep 17 00:00:00 2001 From: hrshjswniii Date: Wed, 10 Jun 2026 01:02:17 +0530 Subject: [PATCH 4/6] [Feature] : Add dedicated Recycle Bin / Trash workflow: --- backend/app/routes/documents.py | 112 ++++++++++- backend/app/services/cleanup.py | 7 + backend/tests/test_documents.py | 87 ++++++++ frontend/package-lock.json | 1 + .../components/document/DocumentSidebar.tsx | 29 ++- .../src/components/document/TrashModal.tsx | 188 ++++++++++++++++++ 6 files changed, 418 insertions(+), 6 deletions(-) create mode 100644 frontend/src/components/document/TrashModal.tsx diff --git a/backend/app/routes/documents.py b/backend/app/routes/documents.py index f4f9577c..126101a8 100644 --- a/backend/app/routes/documents.py +++ b/backend/app/routes/documents.py @@ -9,7 +9,7 @@ import asyncio import concurrent.futures from datetime import datetime, timezone -from typing import Optional +from typing import Optional, List from pathlib import Path import shutil import socket @@ -53,7 +53,8 @@ else: CRAWL4AI_IMPORT_ERROR = None -from sqlalchemy import select +from sqlalchemy import select, or_ +from app.rag.graph_builder import delete_graph logger = logging.getLogger(__name__) settings = get_settings() @@ -588,6 +589,113 @@ def rename_document( return _deserialize_doc(doc) +@router.get("/trash", response_model=List[DocumentResponse]) +def list_trash_documents( + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + List all soft-deleted documents (is_deleted == True) belonging to the user or their active workspaces. + """ + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user.id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + + if workspace_ids: + filters = [ + Document.is_deleted.is_(True), + or_( + Document.user_id == user.id, + Document.workspace_id.in_(workspace_ids) + ) + ] + else: + filters = [ + Document.is_deleted.is_(True), + Document.user_id == user.id + ] + + docs = ( + db.query(Document) + .filter(*filters) + .order_by(Document.deleted_at.desc()) + .all() + ) + + return [_deserialize_doc(d) for d in docs] + + +@router.post("/{document_id}/restore", response_model=DocumentResponse) +def restore_document( + document_id: str, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + Restore a soft-deleted document. + """ + doc = db.query(Document).filter( + Document.id == document_id, + Document.is_deleted.is_(True), + ).first() + + if not doc or not check_document_access(doc, user.id, db): + raise NotFoundException("Document") + + doc.is_deleted = False + doc.deleted_at = None + db.commit() + db.refresh(doc) + + return _deserialize_doc(doc) + + +@router.delete("/{document_id}/purge") +def purge_document( + document_id: str, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + Immediately and permanently hard-deletes the document, including its database record, + physical uploads, vector chunks, and knowledge graph files. + """ + doc = db.query(Document).filter( + Document.id == document_id + ).first() + + if not doc or not check_document_access(doc, user.id, db): + raise NotFoundException("Document") + + # 1. Delete vector chunks + try: + from app.rag.vectorstore import delete_document_chunks + delete_document_chunks(document_id=doc.id, user_id=doc.user_id) + except Exception as e: + logger.warning(f"Error cleaning vectors for {doc.id}: {e}") + + # 2. Delete knowledge graph + try: + delete_graph(user_id=doc.user_id, document_id=doc.id) + except Exception as e: + logger.warning(f"Error deleting graph for {doc.id}: {e}") + + # 3. Delete physical upload file + try: + filepath = os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename) + if os.path.exists(filepath): + os.remove(filepath) + except Exception as e: + logger.warning(f"Error deleting file for {doc.id}: {e}") + + # 4. Delete document record + db.delete(doc) + db.commit() + + return {"message": f"Document '{doc.original_name}' permanently purged successfully"} + + @router.get("/{document_id}", response_model=DocumentResponse) def get_document( document_id: str, diff --git a/backend/app/services/cleanup.py b/backend/app/services/cleanup.py index 70e79b6a..026d6445 100644 --- a/backend/app/services/cleanup.py +++ b/backend/app/services/cleanup.py @@ -82,6 +82,13 @@ def cleanup_old_deleted_documents(): except Exception as e: logger.warning("Error deleting file for %s: %s", doc.id, e) + try: + from app.rag.graph_builder import delete_graph + + delete_graph(user_id=doc.user_id, document_id=doc.id) + except Exception as e: + logger.warning("Error deleting graph for %s: %s", doc.id, e) + db.delete(doc) if old: diff --git a/backend/tests/test_documents.py b/backend/tests/test_documents.py index 9ab16166..7aa98174 100644 --- a/backend/tests/test_documents.py +++ b/backend/tests/test_documents.py @@ -182,3 +182,90 @@ def test_delete_document_soft_deletes_and_hides_document(client, auth_headers, r get_response = client.get(f"/api/v1/documents/{doc_id}", headers=auth_headers) assert get_response.status_code == 404 + + +def test_list_trash_documents(client, auth_headers, ready_document, db_session): + # Set document as soft-deleted + from datetime import datetime, timezone + ready_document.is_deleted = True + ready_document.deleted_at = datetime.now(timezone.utc) + db_session.commit() + + # Get trash + response = client.get("/api/v1/documents/trash", headers=auth_headers) + assert response.status_code == 200 + payload = response.json() + assert len(payload) == 1 + assert payload[0]["id"] == ready_document.id + assert payload[0]["original_name"] == "ready.txt" + + +def test_restore_document(client, auth_headers, ready_document, db_session): + from datetime import datetime, timezone + ready_document.is_deleted = True + ready_document.deleted_at = datetime.now(timezone.utc) + db_session.commit() + + # Verify not in active list + list_response = client.get("/api/v1/documents/", headers=auth_headers) + assert list_response.json()["total"] == 0 + + # Restore + response = client.post(f"/api/v1/documents/{ready_document.id}/restore", headers=auth_headers) + assert response.status_code == 200 + + db_session.refresh(ready_document) + assert ready_document.is_deleted is False + assert ready_document.deleted_at is None + + # Verify back in active list + list_response = client.get("/api/v1/documents/", headers=auth_headers) + assert list_response.json()["total"] == 1 + + +def test_purge_document(client, auth_headers, ready_document, db_session, monkeypatch): + from app.rag import vectorstore + import app.routes.documents + import os + + chunk_deleted = [] + graph_deleted = [] + file_deleted = [] + + monkeypatch.setattr( + vectorstore, + "delete_document_chunks", + lambda document_id, user_id: chunk_deleted.append(document_id) + ) + monkeypatch.setattr( + app.routes.documents, + "delete_graph", + lambda user_id, document_id: graph_deleted.append(document_id) + ) + monkeypatch.setattr( + os.path, + "exists", + lambda path: True + ) + monkeypatch.setattr( + os, + "remove", + lambda path: file_deleted.append(path) + ) + + doc_id = ready_document.id + + # Purge document + response = client.delete(f"/api/v1/documents/{doc_id}/purge", headers=auth_headers) + assert response.status_code == 200 + + # Verify mocks were called + assert doc_id in chunk_deleted + assert doc_id in graph_deleted + assert len(file_deleted) == 1 + + # Verify DB record is gone + refreshed = db_session.get(Document, doc_id) + assert refreshed is None + + diff --git a/frontend/package-lock.json b/frontend/package-lock.json index ac9effab..41685711 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -6631,6 +6631,7 @@ "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, "hasInstallScript": true, "license": "MIT", "optional": true, diff --git a/frontend/src/components/document/DocumentSidebar.tsx b/frontend/src/components/document/DocumentSidebar.tsx index d1079b12..4358b148 100644 --- a/frontend/src/components/document/DocumentSidebar.tsx +++ b/frontend/src/components/document/DocumentSidebar.tsx @@ -16,6 +16,7 @@ import { import { useDropzone } from "react-dropzone"; import { Settings } from "lucide-react"; import DocumentSettings from "./DocumentSettings"; +import TrashModal from "./TrashModal"; import { toast } from "sonner"; import { useWorkspaceStore } from "@/store/workspace-store"; @@ -72,6 +73,7 @@ export default function DocumentSidebar({ const [driveLoading, setDriveLoading] = useState(true); const [driveConnecting, setDriveConnecting] = useState(false); const [driveError, setDriveError] = useState(""); + const [trashOpen, setTrashOpen] = useState(false); useEffect(() => { let cancelled = false; @@ -337,10 +339,22 @@ export default function DocumentSidebar({ {/* ── Documents List ──────────────────────────── */}
-

- {loading - ? t("documents.documentsTitle", { count: "..." }) - : t("documents.documentsTitle", { count: documents.length })} +

+ + {loading + ? t("documents.documentsTitle", { count: "..." }) + : t("documents.documentsTitle", { count: documents.length })} + +

@@ -482,6 +496,13 @@ export default function DocumentSidebar({ }} /> )} + {trashOpen && ( + + )}
); } diff --git a/frontend/src/components/document/TrashModal.tsx b/frontend/src/components/document/TrashModal.tsx new file mode 100644 index 00000000..4300c399 --- /dev/null +++ b/frontend/src/components/document/TrashModal.tsx @@ -0,0 +1,188 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { Trash2, RotateCcw, FileText, Loader2, AlertCircle } from "lucide-react"; +import { api } from "@/lib/api"; +import { toast } from "sonner"; +import type { DocInfo } from "@/app/dashboard/page"; + +interface Props { + open: boolean; + onOpenChange: (open: boolean) => void; + onDocumentsChange: () => void; +} + +export default function TrashModal({ open, onOpenChange, onDocumentsChange }: Props) { + const [trashDocs, setTrashDocs] = useState([]); + const [loading, setLoading] = useState(false); + const [actionId, setActionId] = useState(null); + const [actionType, setActionType] = useState<"restore" | "purge" | null>(null); + + const fetchTrash = async () => { + setLoading(true); + try { + const data = await api.get("/api/v1/documents/trash"); + setTrashDocs(data); + } catch (err) { + console.error("Failed to load trash documents:", err); + toast.error("Failed to load Recycle Bin items"); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + if (open) { + void fetchTrash(); + } + }, [open]); + + const handleRestore = async (doc: DocInfo) => { + setActionId(doc.id); + setActionType("restore"); + try { + await api.post(`/api/v1/documents/${doc.id}/restore`); + toast.success(`🎉 Restored '${doc.original_name}' successfully!`); + // Update local state and trigger parent refresh + setTrashDocs((prev) => prev.filter((d) => d.id !== doc.id)); + onDocumentsChange(); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to restore document"); + } finally { + setActionId(null); + setActionType(null); + } + }; + + const handlePurge = async (doc: DocInfo) => { + if (!confirm(`⚠️ Are you sure you want to permanently delete '${doc.original_name}'? This action is irreversible and will purge all files, vector chunks, and graph data immediately.`)) { + return; + } + + setActionId(doc.id); + setActionType("purge"); + try { + await api.delete(`/api/v1/documents/${doc.id}/purge`); + toast.success(`🗑️ Permanently deleted '${doc.original_name}'`); + setTrashDocs((prev) => prev.filter((d) => d.id !== doc.id)); + onDocumentsChange(); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to delete document"); + } finally { + setActionId(null); + setActionType(null); + } + }; + + const formatSize = (bytes: number) => { + if (bytes < 1024) return `${bytes} B`; + if (bytes < 1048576) return `${(bytes / 1024).toFixed(1)} KB`; + return `${(bytes / 1048576).toFixed(1)} MB`; + }; + + return ( + + + + + + Recycle Bin + + + Items in the Recycle Bin will be permanently purged after 30 days. You can restore them or delete them permanently now. + + + +
+ + {loading && trashDocs.length === 0 ? ( +
+ +

Loading trashed files...

+
+ ) : trashDocs.length === 0 ? ( +
+
+ +
+

Recycle Bin is empty

+

+ Soft-deleted documents will appear here. +

+
+ ) : ( +
+ {trashDocs.map((doc) => ( +
+
+ +
+

+ {doc.original_name} +

+
+ {formatSize(doc.file_size)} + + Deleted: {new Date(doc.uploaded_at).toLocaleDateString()} +
+
+
+ +
+ + +
+
+ ))} +
+ )} +
+
+ +
+ +
+
+
+ ); +} From cf00db5d149c4eea3701435fb0e65b277108c146 Mon Sep 17 00:00:00 2001 From: hrshjswniii Date: Wed, 10 Jun 2026 01:27:30 +0530 Subject: [PATCH 5/6] [Feature] : Add support for multi-document selection within the chat interface: --- backend/app/rag/agent.py | 31 ++++++- backend/app/rag/graph_retriever.py | 20 ++++- backend/app/rag/retriever.py | 31 ++++++- backend/app/rag/tools.py | 3 + backend/app/routes/chat.py | 84 +++++++++++++++++-- backend/tests/test_chat.py | 31 +++++++ backend/tests/test_rag_tools.py | 2 +- backend/tests/test_retriever.py | 45 ++++++++++ frontend/src/app/dashboard/page.tsx | 9 ++ frontend/src/components/chat/ChatPanel.tsx | 30 +++++-- .../components/document/DocumentSidebar.tsx | 39 +++++++-- 11 files changed, 297 insertions(+), 28 deletions(-) diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index d8bc0990..454d2cb8 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -53,6 +53,7 @@ def _format_chat_history(messages: List[Dict[str, str]]) -> str: def get_agent_executor( user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -61,7 +62,13 @@ def get_agent_executor( """Initialize the LangChain ReAct agent executor.""" # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, workspace=workspace, top_k=top_k) + pdf_tool = PDFSearchTool( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + workspace=workspace, + top_k=top_k, + ) tools = [pdf_tool, MathTool(), WebSearchTool()] # Initialize LLM @@ -120,6 +127,7 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -148,7 +156,15 @@ def generate_answer( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) + executor, pdf_tool, formatted_history = get_agent_executor( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + hf_token=hf_token, + top_k=top_k, + chat_history=chat_history, + workspace=workspace, + ) result = executor.invoke({"input": question, "chat_history": formatted_history}) raw_answer = result.get("output", "") @@ -193,6 +209,7 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, @@ -222,7 +239,15 @@ def generate_answer_stream( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) + executor, pdf_tool, formatted_history = get_agent_executor( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + hf_token=hf_token, + top_k=top_k, + chat_history=chat_history, + workspace=workspace, + ) sources_sent = False diff --git a/backend/app/rag/graph_retriever.py b/backend/app/rag/graph_retriever.py index 39841aa1..27378300 100644 --- a/backend/app/rag/graph_retriever.py +++ b/backend/app/rag/graph_retriever.py @@ -18,10 +18,21 @@ settings = get_settings() -def _candidate_graphs(user_id: str, document_id: Optional[str]) -> Iterable[nx.Graph]: +def _candidate_graphs( + user_id: str, + document_id: Optional[str], + document_ids: Optional[List[str]] = None, +) -> Iterable[nx.Graph]: if document_id: graph = load_graph(user_id, document_id) return [graph] if graph is not None else [] + elif document_ids: + graphs = [] + for doc_id in document_ids: + graph = load_graph(user_id, doc_id) + if graph is not None: + graphs.append(graph) + return graphs graphs = [] for path in iter_graph_paths(user_id): @@ -67,12 +78,17 @@ def get_entity_context( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, ) -> str: """Return compact graph relationship context relevant to the query.""" relationships: Dict[Tuple[str, str], Dict[str, object]] = {} try: - graphs = _candidate_graphs(user_id=user_id, document_id=document_id) + graphs = _candidate_graphs( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + ) for graph in graphs: matched_nodes = _match_query_nodes(graph, query) diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index 65d96700..cffeb003 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -211,9 +211,11 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @trace_function( "retrieve", - metadata_factory=lambda query, user_id, document_id=None, top_k=None: { + metadata_factory=lambda query, user_id, document_id=None, top_k=None, **kwargs: { "user_id": user_id, "document_id": document_id, + "document_ids": kwargs.get("document_ids"), + "workspace": kwargs.get("workspace"), "embedding_model": settings.EMBEDDING_MODEL, "reranker_model": settings.RERANKER_MODEL, "top_k_retrieval": settings.TOP_K_RETRIEVAL, @@ -224,6 +226,7 @@ def retrieve( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, workspace: Optional[str] = None, top_k: Optional[int] = None, ) -> List[Dict[str, Any]]: @@ -264,6 +267,32 @@ def retrieve( if not has_access: return [] active_doc_ids = [str(doc.id)] + elif document_ids: + # Filter by a list of document IDs and verify access for each + filters = [ + Document.id.in_(document_ids), + Document.is_deleted.is_(False), + ] + docs = db.query(Document).filter(*filters).all() + + accessible_docs = [] + for doc in docs: + has_access = False + if str(doc.user_id) == str(user_id): + has_access = True + elif doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + has_access = True + if has_access: + accessible_docs.append(doc) + + if not accessible_docs: + return [] + active_doc_ids = [str(doc.id) for doc in accessible_docs] else: # Filter documents by workspace filters = [Document.is_deleted.is_(False)] diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index c03c56ab..e75ebb99 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool): user_id: str document_id: Optional[str] = None + document_ids: Optional[List[str]] = None workspace: Optional[str] = None top_k: Optional[int] = None # We'll store sources here to retrieve them after agent execution @@ -168,6 +169,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, workspace=self.workspace, top_k=self.top_k, ) @@ -193,6 +195,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, ) main_context = "\n\n".join(context_parts) diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 03e9bff1..adbe6112 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -108,6 +108,7 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): question = payload.get("question") document_id = payload.get("document_id") + document_ids = payload.get("document_ids") session_id = payload.get("session_id") from app.rag.security import validate_user_input, UnsafePromptError @@ -141,6 +142,21 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): await websocket.send_json({"type": "error", "data": detail}) await websocket.close() return + elif document_ids: + for doc_id in document_ids: + doc = db.query(Document).filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ).first() + if not doc: + await websocket.send_json({"type": "error", "data": f"Document {doc_id} not found"}) + await websocket.close() + return + if doc.status != "ready": + await websocket.send_json({"type": "error", "data": f"Document '{doc.original_name}' is still {doc.status}."}) + await websocket.close() + return # Resolve or create session if not session_id: @@ -167,7 +183,14 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Save user message - _save_message(db, user.id, document_id, "user", question, session_id=session_id) + _save_message( + db, + user.id, + document_id if not document_ids else None, + "user", + question, + session_id=session_id, + ) # Stream answer using existing generator and forward structured events try: @@ -175,6 +198,7 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): question=question, user_id=user.id, document_id=document_id, + document_ids=document_ids, hf_token=user.hf_token, chat_history=chat_history, ): @@ -441,6 +465,7 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, @@ -452,6 +477,7 @@ def generate_answer( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, @@ -463,6 +489,7 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, @@ -474,6 +501,7 @@ def generate_answer_stream( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, @@ -533,6 +561,23 @@ def ask_question( # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + elif payload.document_ids: + for doc_id in payload.document_ids: + doc = ( + db.query(Document) + .filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .first() + ) + + if not doc: + raise NotFoundException(f"Document {doc_id}") + + if doc.status != "ready": + raise ValidationException(f"Document '{doc.original_name}' is still {doc.status}. Please wait for processing to complete.") # Resolve or create session session_id = payload.session_id @@ -560,8 +605,9 @@ def ask_question( chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Cache check — return instantly if this (question, document) was answered before + doc_ids_str = ",".join(sorted(payload.document_ids)) if payload.document_ids else "" cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, ) if cached_answer is not None: @@ -576,6 +622,7 @@ def ask_question( question=payload.question, user_id=user.id, document_id=payload.document_id, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, @@ -584,15 +631,15 @@ def ask_question( # Store result in cache for future identical questions set_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, answer=result["answer"], ) # Save to chat history - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, payload.document_id if not payload.document_ids else None, "user", payload.question, session_id=session_id) _save_message( - db, user.id, payload.document_id, "assistant", result["answer"], result["sources"], session_id=session_id + db, user.id, payload.document_id if not payload.document_ids else None, "assistant", result["answer"], result["sources"], session_id=session_id ) return ChatResponse( @@ -653,6 +700,23 @@ def ask_question_stream( # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + elif payload.document_ids: + for doc_id in payload.document_ids: + doc = ( + db.query(Document) + .filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .first() + ) + + if not doc: + raise NotFoundException(f"Document {doc_id}") + + if doc.status != "ready": + raise ValidationException(f"Document '{doc.original_name}' is still {doc.status}. Please wait for processing to complete.") started_at = time.perf_counter() @@ -682,11 +746,12 @@ def ask_question_stream( chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Save user message immediately - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, payload.document_id if not payload.document_ids else None, "user", payload.question, session_id=session_id) # Cache check before starting the stream + doc_ids_str = ",".join(sorted(payload.document_ids)) if payload.document_ids else "" cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, ) if cached_answer is not None: @@ -718,6 +783,7 @@ def event_stream(): question=payload.question, user_id=user.id, document_id=payload.document_id, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, @@ -739,7 +805,7 @@ def event_stream(): # Cache the full answer for future identical questions if full_answer: set_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, answer=full_answer, ) @@ -749,7 +815,7 @@ def event_stream(): with get_db_session() as save_db: _save_message( - save_db, user.id, payload.document_id, "assistant", full_answer, sources, session_id=session_id + save_db, user.id, payload.document_id if not payload.document_ids else None, "assistant", full_answer, sources, session_id=session_id ) finally: record_query_response_time(time.perf_counter() - started_at) diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py index 1073c9c2..3e5386a1 100644 --- a/backend/tests/test_chat.py +++ b/backend/tests/test_chat.py @@ -127,3 +127,34 @@ class MockResponse: generate_answer(question="hello?", user_id="some-user", hf_token=None) from app.config import get_settings assert called_with_token == get_settings().HF_TOKEN + + +def test_chat_ask_success_with_document_ids(client, auth_headers, ready_document, monkeypatch): + monkeypatch.setattr( + "app.routes.chat.generate_answer", + lambda question, user_id, document_id=None, document_ids=None, **kwargs: { + "answer": "Mocked answer for multiple docs", + "sources": [ + { + "text": "Mock source", + "filename": "ready.txt", + "page": 1, + "score": 0.99, + "confidence": 99.0, + } + ], + }, + ) + + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "What is in the doc?", "document_ids": [ready_document.id]}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["answer"] == "Mocked answer for multiple docs" + assert payload["document_id"] is None + assert payload["sources"][0]["filename"] == "ready.txt" + diff --git a/backend/tests/test_rag_tools.py b/backend/tests/test_rag_tools.py index 12eddb81..bf33a8b0 100644 --- a/backend/tests/test_rag_tools.py +++ b/backend/tests/test_rag_tools.py @@ -158,7 +158,7 @@ def fake_retrieve(query, user_id, document_id=None, workspace=None, top_k=None, retrieve_calls.append((query, user_id, document_id)) return chunks - def fake_get_entity_context(query, user_id, document_id=None): + def fake_get_entity_context(query, user_id, document_id=None, **kwargs): graph_calls.append((query, user_id, document_id)) return "Alpha -> acquired -> Beta" diff --git a/backend/tests/test_retriever.py b/backend/tests/test_retriever.py index 66e29083..4448cdce 100644 --- a/backend/tests/test_retriever.py +++ b/backend/tests/test_retriever.py @@ -137,3 +137,48 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=N # Should only query for the active document ID assert captured_doc_ids == ["active-doc-id"] + +def test_retrieve_with_document_ids_list_and_rbac_checks(db_session, user, monkeypatch): + from app.models import Document + from app.rag import retriever + + # Create two documents: one owned by user (should be allowed), one owned by someone else (should be excluded) + owned_doc = Document( + id="owned-doc-id", + user_id=user.id, + filename="owned.pdf", + original_name="owned.pdf", + is_deleted=False, + ) + other_doc = Document( + id="other-doc-id", + user_id="another-user-id", + filename="other.pdf", + original_name="other.pdf", + is_deleted=False, + ) + db_session.add(owned_doc) + db_session.add(other_doc) + db_session.commit() + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + monkeypatch.setattr(retriever, "transform_query", lambda _query: ["query"]) + monkeypatch.setattr(retriever, "embed_query", lambda query: "embedding") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + captured_doc_ids = [] + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + nonlocal captured_doc_ids + captured_doc_ids = document_ids + return [] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr(retriever.CustomBM25Retriever, "_get_relevant_documents", lambda *args, **kwargs: []) + + # Retrieve with both document IDs + retriever.retrieve("test query", user_id=user.id, document_ids=["owned-doc-id", "other-doc-id"]) + + # Should only query for the owned document ID (excluding other_doc because of RBAC check) + assert captured_doc_ids == ["owned-doc-id"] + + diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index 3384e242..d9ef446c 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -64,7 +64,12 @@ export default function DashboardPage() { const [documents, setDocuments] = useState([]); const prevDocsRef = useRef>({}); const [activeDoc, setActiveDoc] = useState(null); + const [selectedDocIds, setSelectedDocIds] = useState([]); const [pdfPage, setPdfPage] = useState(1); + + useEffect(() => { + setSelectedDocIds([]); + }, [workspace]); const [pdfHighlightTarget, setPdfHighlightTarget] = useState<{ page: number; rects?: { @@ -197,9 +202,12 @@ export default function DashboardPage() { onSelectDoc={(doc) => { setActiveDoc(doc); setPdfPage(1); + setSelectedDocIds([]); }} onDocumentsChange={loadDocuments} onDocumentRenamed={handleDocumentRenamed} + selectedDocIds={selectedDocIds} + onSelectDocsChange={setSelectedDocIds} /> ); @@ -237,6 +245,7 @@ export default function DashboardPage() {
{ setPdfPage(target.page); setPdfHighlightTarget({ page: target.page, rects: target.highlightRects }); diff --git a/frontend/src/components/chat/ChatPanel.tsx b/frontend/src/components/chat/ChatPanel.tsx index 51ae666f..7efc5fa0 100644 --- a/frontend/src/components/chat/ChatPanel.tsx +++ b/frontend/src/components/chat/ChatPanel.tsx @@ -56,10 +56,11 @@ interface CitationTarget { interface Props { activeDoc: DocInfo | null; + selectedDocIds?: string[]; onCitationClick: (target: CitationTarget) => void; } -export default function ChatPanel({ activeDoc, onCitationClick }: Props) { +export default function ChatPanel({ activeDoc, selectedDocIds = [], onCitationClick }: Props) { const { t, i18n } = useTranslation(); const workspace = useWorkspaceStore((s) => s.workspace); const messages = useChatStore((state) => state.messages); @@ -200,7 +201,13 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { const wsDone = new Promise((resolve, reject) => { ws.onopen = () => { // Send initial payload - ws.send(JSON.stringify({ question, document_id: activeDoc?.id || null, session_id: activeSessionId, workspace })); + ws.send(JSON.stringify({ + question, + document_id: selectedDocIds.length > 0 ? null : (activeDoc?.id || null), + document_ids: selectedDocIds.length > 0 ? selectedDocIds : null, + session_id: activeSessionId, + workspace + })); }; // If WS doesn't open within 800ms, treat as failure and fallback @@ -275,7 +282,8 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { try { const stream = api.streamPost("/api/v1/chat/ask/stream", { question, - document_id: activeDoc?.id || null, + document_id: selectedDocIds.length > 0 ? null : (activeDoc?.id || null), + document_ids: selectedDocIds.length > 0 ? selectedDocIds : null, session_id: activeSessionId, workspace, }); @@ -547,10 +555,16 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) {

- {activeDoc ? t("chat.askAboutDocument") : t("chat.selectDocument")} + {selectedDocIds.length > 0 + ? `Chatting with ${selectedDocIds.length} selected ${selectedDocIds.length === 1 ? "file" : "files"}` + : activeDoc + ? t("chat.askAboutDocument") + : t("chat.selectDocument")}

- {activeDoc + {selectedDocIds.length > 0 + ? `Ask a question to query and synthesize across the selected ${selectedDocIds.length === 1 ? "file" : "files"}.` + : activeDoc ? t("chat.readyPrompt", { name: activeDoc.original_name }) : t("chat.uploadPrompt")}

@@ -627,7 +641,9 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { onChange={(e) => setInput(e.target.value)} onKeyDown={handleKeyDown} placeholder={ - activeDoc + selectedDocIds.length > 0 + ? `Ask about the ${selectedDocIds.length} selected ${selectedDocIds.length === 1 ? "file" : "files"}...` + : activeDoc ? t("chat.askPlaceholder", { name: activeDoc.original_name }) : t("chat.selectPlaceholder") } @@ -697,7 +713,7 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { )} - {messages.length > 0 && ( + {messages.length > 0 && activeDoc && selectedDocIds.length === 0 && ( <> {/* Export dropdown */}
diff --git a/frontend/src/components/document/DocumentSidebar.tsx b/frontend/src/components/document/DocumentSidebar.tsx index 4358b148..0f42a9e7 100644 --- a/frontend/src/components/document/DocumentSidebar.tsx +++ b/frontend/src/components/document/DocumentSidebar.tsx @@ -27,6 +27,8 @@ interface Props { onSelectDoc: (doc: DocInfo) => void; onDocumentsChange: () => void; onDocumentRenamed: (doc: DocInfo) => void; + selectedDocIds?: string[]; + onSelectDocsChange?: (ids: string[]) => void; } function DocumentListSkeleton() { @@ -58,6 +60,8 @@ export default function DocumentSidebar({ onSelectDoc, onDocumentsChange, onDocumentRenamed, + selectedDocIds = [], + onSelectDocsChange, }: Props) { const { t } = useTranslation(); const workspace = useWorkspaceStore((s) => s.workspace); @@ -340,11 +344,18 @@ export default function DocumentSidebar({ {/* ── Documents List ──────────────────────────── */}

- - {loading - ? t("documents.documentsTitle", { count: "..." }) - : t("documents.documentsTitle", { count: documents.length })} - +
+ + {loading + ? t("documents.documentsTitle", { count: "..." }) + : t("documents.documentsTitle", { count: documents.length })} + + {selectedDocIds && selectedDocIds.length > 0 && ( + + {selectedDocIds.length} selected + + )} +