diff --git a/backend/app/database.py b/backend/app/database.py index 8b644760..f74c408b 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -166,6 +166,7 @@ def _migrate_schema(): ("documents", "drive_file_id", "ALTER TABLE documents ADD COLUMN drive_file_id VARCHAR(255)"), ("documents", "drive_folder_id", "ALTER TABLE documents ADD COLUMN drive_folder_id VARCHAR(255)"), ("documents", "drive_synced_at", "ALTER TABLE documents ADD COLUMN drive_synced_at TIMESTAMP"), + ("documents", "workspace_id", "ALTER TABLE documents ADD COLUMN workspace_id VARCHAR(36)"), ] for table, column, ddl in docs_migrations: if column not in existing_docs_columns: diff --git a/backend/app/main.py b/backend/app/main.py index e3af1334..8947836c 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -77,6 +77,13 @@ async def document_cleanup_job(): except Exception as e: logger.warning(f"Auto-cleanup: Error deleting vectors for document {doc.id}: {e}") + # Delete knowledge graph + try: + from app.rag.graph_builder import delete_graph + delete_graph(user_id=doc.user_id, document_id=doc.id) + except Exception as e: + logger.warning(f"Auto-cleanup: Error deleting graph for document {doc.id}: {e}") + # Delete database record db.delete(doc) diff --git a/backend/app/models.py b/backend/app/models.py index a8611fe6..a812213b 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -163,6 +163,11 @@ class User(Base): back_populates="user", cascade="all, delete-orphan", ) + workspace_memberships = relationship( + "WorkspaceMembership", + back_populates="user", + cascade="all, delete-orphan", + ) class ApiKey(Base): @@ -185,6 +190,32 @@ class ApiKey(Base): user = relationship("User", back_populates="api_keys") +class Workspace(Base): + __tablename__ = "workspaces" + + id = Column(String(36), primary_key=True, default=generate_uuid) + name = Column(String(100), nullable=False) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + memberships = relationship("WorkspaceMembership", back_populates="workspace", cascade="all, delete-orphan") + documents = relationship("Document", back_populates="workspace") + + +class WorkspaceMembership(Base): + __tablename__ = "workspace_memberships" + + id = Column(String(36), primary_key=True, default=generate_uuid) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=False, index=True) + user_id = Column(GUID, ForeignKey("users.id"), nullable=False, index=True) + role = Column(String(20), default="member", nullable=False) # "admin" | "member" + joined_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + workspace = relationship("Workspace", back_populates="memberships") + user = relationship("User", back_populates="workspace_memberships") + + class WorkspaceInvitation(Base): __tablename__ = "workspace_invitations" @@ -254,6 +285,7 @@ class Document(Base): drive_synced_at = Column(DateTime, nullable=True) is_deleted = Column(Boolean, default=False, nullable=False, index=True) deleted_at = Column(DateTime, nullable=True) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=True, index=True) processing_progress = Column(Integer, default=0) processing_stage = Column(String(20), default="queued") retry_count = Column(Integer, default=0) @@ -264,6 +296,7 @@ class Document(Base): # Relationships owner = relationship("User", back_populates="documents") + workspace = relationship("Workspace", back_populates="documents") messages = relationship( "ChatMessage", back_populates="document", diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index 4917cfe5..454d2cb8 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -53,14 +53,22 @@ def _format_chat_history(messages: List[Dict[str, str]]) -> str: def get_agent_executor( user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ): """Initialize the LangChain ReAct agent executor.""" # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) + pdf_tool = PDFSearchTool( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + workspace=workspace, + top_k=top_k, + ) tools = [pdf_tool, MathTool(), WebSearchTool()] # Initialize LLM @@ -119,9 +127,11 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Dict[str, Any]: """ Agentic generation: retrieve via tools → reason → generate answer. @@ -146,7 +156,15 @@ def generate_answer( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + hf_token=hf_token, + top_k=top_k, + chat_history=chat_history, + workspace=workspace, + ) result = executor.invoke({"input": question, "chat_history": formatted_history}) raw_answer = result.get("output", "") @@ -191,9 +209,11 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Generator[str, None, None]: """ Streaming Agentic pipeline. @@ -219,7 +239,15 @@ def generate_answer_stream( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + hf_token=hf_token, + top_k=top_k, + chat_history=chat_history, + workspace=workspace, + ) sources_sent = False diff --git a/backend/app/rag/bm25.py b/backend/app/rag/bm25.py index 84d89c88..9ca8d196 100644 --- a/backend/app/rag/bm25.py +++ b/backend/app/rag/bm25.py @@ -108,6 +108,7 @@ def query_bm25( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, top_k: int = 10, ) -> List[Dict[str, Any]]: """ @@ -127,6 +128,11 @@ def query_bm25( all_results = [] for path in glob.glob(os.path.join(user_dir, "*.pkl")): + # Filter by document_ids if provided + if document_ids is not None: + doc_id = os.path.basename(path).rsplit(".", 1)[0] + if doc_id not in document_ids: + continue results = _query_single_index(path, tokenized_query, top_k) all_results.extend(results) diff --git a/backend/app/rag/graph_retriever.py b/backend/app/rag/graph_retriever.py index 39841aa1..27378300 100644 --- a/backend/app/rag/graph_retriever.py +++ b/backend/app/rag/graph_retriever.py @@ -18,10 +18,21 @@ settings = get_settings() -def _candidate_graphs(user_id: str, document_id: Optional[str]) -> Iterable[nx.Graph]: +def _candidate_graphs( + user_id: str, + document_id: Optional[str], + document_ids: Optional[List[str]] = None, +) -> Iterable[nx.Graph]: if document_id: graph = load_graph(user_id, document_id) return [graph] if graph is not None else [] + elif document_ids: + graphs = [] + for doc_id in document_ids: + graph = load_graph(user_id, doc_id) + if graph is not None: + graphs.append(graph) + return graphs graphs = [] for path in iter_graph_paths(user_id): @@ -67,12 +78,17 @@ def get_entity_context( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, ) -> str: """Return compact graph relationship context relevant to the query.""" relationships: Dict[Tuple[str, str], Dict[str, object]] = {} try: - graphs = _candidate_graphs(user_id=user_id, document_id=document_id) + graphs = _candidate_graphs( + user_id=user_id, + document_id=document_id, + document_ids=document_ids, + ) for graph in graphs: matched_nodes = _match_query_nodes(graph, query) diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index e542c17f..cffeb003 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -41,6 +41,7 @@ def invoke(self, query): class CustomVectorRetriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -51,6 +52,7 @@ def _get_relevant_documents( query_embedding=query_vector, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] @@ -59,6 +61,7 @@ def _get_relevant_documents( class CustomBM25Retriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -69,11 +72,13 @@ def _get_relevant_documents( query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] + def transform_query(query: str) -> List[str]: """Rewrite a user question into multiple retrieval-friendly search queries.""" original_query = query.strip() @@ -206,9 +211,11 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]: @trace_function( "retrieve", - metadata_factory=lambda query, user_id, document_id=None, top_k=None: { + metadata_factory=lambda query, user_id, document_id=None, top_k=None, **kwargs: { "user_id": user_id, "document_id": document_id, + "document_ids": kwargs.get("document_ids"), + "workspace": kwargs.get("workspace"), "embedding_model": settings.EMBEDDING_MODEL, "reranker_model": settings.RERANKER_MODEL, "top_k_retrieval": settings.TOP_K_RETRIEVAL, @@ -219,6 +226,8 @@ def retrieve( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, + workspace: Optional[str] = None, top_k: Optional[int] = None, ) -> List[Dict[str, Any]]: """ @@ -228,17 +237,106 @@ def retrieve( Returns chunks with confidence scores. """ + from app.database import SessionLocal + from app.models import Document, WorkspaceMembership + + # Fetch active document IDs + db = SessionLocal() + try: + if document_id: + # Check if specific document is active (not deleted) + doc = db.query(Document).filter( + Document.id == document_id, + Document.is_deleted.is_(False), + ).first() + if not doc: + return [] + + # Verify user has access to the document + has_access = False + if str(doc.user_id) == str(user_id): + has_access = True + elif doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + has_access = True + + if not has_access: + return [] + active_doc_ids = [str(doc.id)] + elif document_ids: + # Filter by a list of document IDs and verify access for each + filters = [ + Document.id.in_(document_ids), + Document.is_deleted.is_(False), + ] + docs = db.query(Document).filter(*filters).all() + + accessible_docs = [] + for doc in docs: + has_access = False + if str(doc.user_id) == str(user_id): + has_access = True + elif doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + has_access = True + if has_access: + accessible_docs.append(doc) + + if not accessible_docs: + return [] + active_doc_ids = [str(doc.id) for doc in accessible_docs] + else: + # Filter documents by workspace + filters = [Document.is_deleted.is_(False)] + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return [] + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + filters.append(Document.user_id == user_id) + filters.append(Document.workspace_id.is_(None)) + else: + # Specific workspace ID + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user_id, + ).first() + if not membership: + return [] + filters.append(Document.workspace_id == workspace) + + docs = db.query(Document).filter(*filters).all() + if not docs: + return [] + active_doc_ids = [str(doc.id) for doc in docs] + finally: + db.close() + # ── Stage 1: Hybrid Search with Query Transformation ───────────── effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL vector_retriever = CustomVectorRetriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) bm25_retriever = CustomBM25Retriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 03813756..e75ebb99 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -156,6 +156,8 @@ class PDFSearchTool(BaseTool): user_id: str document_id: Optional[str] = None + document_ids: Optional[List[str]] = None + workspace: Optional[str] = None top_k: Optional[int] = None # We'll store sources here to retrieve them after agent execution last_sources: List[Dict[str, Any]] = [] @@ -167,6 +169,8 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, + workspace=self.workspace, top_k=self.top_k, ) @@ -191,6 +195,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, ) main_context = "\n\n".join(context_parts) diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 00f14cc3..adbe6112 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -108,6 +108,7 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): question = payload.get("question") document_id = payload.get("document_id") + document_ids = payload.get("document_ids") session_id = payload.get("session_id") from app.rag.security import validate_user_input, UnsafePromptError @@ -141,6 +142,21 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): await websocket.send_json({"type": "error", "data": detail}) await websocket.close() return + elif document_ids: + for doc_id in document_ids: + doc = db.query(Document).filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ).first() + if not doc: + await websocket.send_json({"type": "error", "data": f"Document {doc_id} not found"}) + await websocket.close() + return + if doc.status != "ready": + await websocket.send_json({"type": "error", "data": f"Document '{doc.original_name}' is still {doc.status}."}) + await websocket.close() + return # Resolve or create session if not session_id: @@ -167,7 +183,14 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Save user message - _save_message(db, user.id, document_id, "user", question, session_id=session_id) + _save_message( + db, + user.id, + document_id if not document_ids else None, + "user", + question, + session_id=session_id, + ) # Stream answer using existing generator and forward structured events try: @@ -175,6 +198,7 @@ async def chat_ws(websocket: WebSocket, token: Optional[str] = Query(None)): question=question, user_id=user.id, document_id=document_id, + document_ids=document_ids, hf_token=user.hf_token, chat_history=chat_history, ): @@ -441,9 +465,11 @@ def generate_answer( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, + workspace: Optional[str] = None, ): from app.rag.agent import generate_answer as _generate_answer @@ -451,9 +477,11 @@ def generate_answer( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, + workspace=workspace, ) @@ -461,9 +489,11 @@ def generate_answer_stream( question: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, + workspace: Optional[str] = None, ): from app.rag.agent import generate_answer_stream as _generate_answer_stream @@ -471,9 +501,11 @@ def generate_answer_stream( question=question, user_id=user_id, document_id=document_id, + document_ids=document_ids, hf_token=hf_token, top_k=top_k, chat_history=chat_history, + workspace=workspace, ) @@ -529,6 +561,23 @@ def ask_question( # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + elif payload.document_ids: + for doc_id in payload.document_ids: + doc = ( + db.query(Document) + .filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .first() + ) + + if not doc: + raise NotFoundException(f"Document {doc_id}") + + if doc.status != "ready": + raise ValidationException(f"Document '{doc.original_name}' is still {doc.status}. Please wait for processing to complete.") # Resolve or create session session_id = payload.session_id @@ -556,8 +605,9 @@ def ask_question( chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Cache check — return instantly if this (question, document) was answered before + doc_ids_str = ",".join(sorted(payload.document_ids)) if payload.document_ids else "" cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, ) if cached_answer is not None: @@ -572,22 +622,24 @@ def ask_question( question=payload.question, user_id=user.id, document_id=payload.document_id, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ) # Store result in cache for future identical questions set_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, answer=result["answer"], ) # Save to chat history - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, payload.document_id if not payload.document_ids else None, "user", payload.question, session_id=session_id) _save_message( - db, user.id, payload.document_id, "assistant", result["answer"], result["sources"], session_id=session_id + db, user.id, payload.document_id if not payload.document_ids else None, "assistant", result["answer"], result["sources"], session_id=session_id ) return ChatResponse( @@ -648,6 +700,23 @@ def ask_question_stream( # Update last_accessed_at timestamp doc.last_accessed_at = datetime.now(timezone.utc) db.commit() + elif payload.document_ids: + for doc_id in payload.document_ids: + doc = ( + db.query(Document) + .filter( + Document.id == doc_id, + Document.user_id == user.id, + Document.is_deleted.is_(False), + ) + .first() + ) + + if not doc: + raise NotFoundException(f"Document {doc_id}") + + if doc.status != "ready": + raise ValidationException(f"Document '{doc.original_name}' is still {doc.status}. Please wait for processing to complete.") started_at = time.perf_counter() @@ -677,11 +746,12 @@ def ask_question_stream( chat_history = [{"role": m.role, "content": m.content} for m in recent_messages] # Save user message immediately - _save_message(db, user.id, payload.document_id, "user", payload.question, session_id=session_id) + _save_message(db, user.id, payload.document_id if not payload.document_ids else None, "user", payload.question, session_id=session_id) # Cache check before starting the stream + doc_ids_str = ",".join(sorted(payload.document_ids)) if payload.document_ids else "" cached_answer = get_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, ) if cached_answer is not None: @@ -713,9 +783,11 @@ def event_stream(): question=payload.question, user_id=user.id, document_id=payload.document_id, + document_ids=payload.document_ids, hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ): yield chunk @@ -733,7 +805,7 @@ def event_stream(): # Cache the full answer for future identical questions if full_answer: set_cached_response( - document_id=str(payload.document_id or ""), + document_id=str(payload.document_id or doc_ids_str), question=payload.question, answer=full_answer, ) @@ -743,7 +815,7 @@ def event_stream(): with get_db_session() as save_db: _save_message( - save_db, user.id, payload.document_id, "assistant", full_answer, sources, session_id=session_id + save_db, user.id, payload.document_id if not payload.document_ids else None, "assistant", full_answer, sources, session_id=session_id ) finally: record_query_response_time(time.perf_counter() - started_at) diff --git a/backend/app/routes/documents.py b/backend/app/routes/documents.py index 5aa5c73f..126101a8 100644 --- a/backend/app/routes/documents.py +++ b/backend/app/routes/documents.py @@ -9,14 +9,14 @@ import asyncio import concurrent.futures from datetime import datetime, timezone -from typing import Optional +from typing import Optional, List from pathlib import Path import shutil import socket import ipaddress import tempfile from urllib.parse import urlparse -from fastapi import APIRouter, Depends, UploadFile, File, status, Query, BackgroundTasks +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status, Query, BackgroundTasks, Form from fastapi.responses import FileResponse from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ AppException, ForbiddenException, ) -from app.models import User, Document +from app.models import User, Document, Workspace, WorkspaceMembership from app.schemas import ( DocumentResponse, DocumentListResponse, @@ -53,13 +53,58 @@ else: CRAWL4AI_IMPORT_ERROR = None -from sqlalchemy import select +from sqlalchemy import select, or_ +from app.rag.graph_builder import delete_graph logger = logging.getLogger(__name__) settings = get_settings() router = APIRouter(prefix="/documents", tags=["Documents"]) +def get_target_workspace_id(workspace_param: Optional[str], user_id: str, db: Session) -> Optional[str]: + if not workspace_param or workspace_param == "personal": + return None + + if workspace_param == "company": + # Check if the user is a member of any workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).first() + if membership: + return membership.workspace_id + + # If not, create a default "Company" workspace for them + workspace = Workspace(name="Company") + db.add(workspace) + db.commit() + db.refresh(workspace) + + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=user_id, + role="admin", + ) + db.add(membership) + db.commit() + return workspace.id + + # If a specific workspace ID was passed, just return it + return workspace_param + + +def check_document_access(doc: Document, user_id: str, db: Session) -> bool: + if str(doc.user_id) == str(user_id): + return True + if doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + return True + return False + + ALLOWED_MIME_TYPES = settings.ALLOWED_MIME_TYPES def _deserialize_doc(doc: Document) -> DocumentResponse: @@ -190,6 +235,7 @@ async def _crawl(): @router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_202_ACCEPTED) async def upload_document( file: UploadFile = File(...), + workspace: Optional[str] = Form(None), background_tasks: BackgroundTasks = None, user: User = Depends(get_current_user), db: Session = Depends(get_db), @@ -205,10 +251,11 @@ async def upload_document( Args: file: The uploaded file, provided as a multipart/form-data field in the request. + workspace: Optional workspace context ('personal', 'company', or specific UUID). background_tasks: FastAPI BackgroundTasks instance for in-process fallback execution. user: The currently authenticated user, injected by the `get_current_user` dependency. db: Database session, injected by the `get_db` dependency. - + Returns: DocumentResponse: The created document record, validated against the response model (includes id, filename, original_name, file_size, status, etc.). @@ -245,6 +292,11 @@ async def upload_document( file_size = Path(filepath).stat().st_size + # Resolve target workspace + if not isinstance(workspace, str): + workspace = None + target_workspace_id = get_target_workspace_id(workspace, user.id, db) + # ── Create database record ─────────────────────── document = Document( user_id=user.id, @@ -252,6 +304,7 @@ async def upload_document( original_name=file.filename, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -356,6 +409,9 @@ async def upload_document_url( url_path = parsed.path.rstrip("/") original_name = f"{parsed.netloc}{url_path or ''}.txt" + # Resolve target workspace + target_workspace_id = get_target_workspace_id(payload.workspace, user.id, db) + # ── Create database record ───────────────────────────── document = Document( user_id=user.id, @@ -363,6 +419,7 @@ async def upload_document_url( original_name=original_name, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -422,11 +479,10 @@ def get_document_status( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") return DocumentStatusResponse.model_validate(doc) @@ -436,36 +492,53 @@ def get_document_status( def list_documents( page: int = Query(1, ge=1), per_page: int = Query(20, ge=1), + workspace: Optional[str] = Query(None), user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """ - List all documents for the authenticated user with pagination. + List all documents for the authenticated user with pagination, filtered by workspace. - Returns a paginated list of documents belonging to the current user, + Returns a paginated list of documents belonging to the current user or their workspaces, ordered by upload date (newest first). - - Args: - page: The page number to retrieve (1: indexed). Defaults to 1. - per_page: The number of documents to return per page. Defaults to 20. - user: The currently authenticated user, injected by the `get_current_user` dependency. - db: Database session, injected by the `get_db` dependency. - - Returns: - DocumentListResponse: A response model containing: - - items: A list of DocumentResponse objects for the current page. - - total: The total number of documents for the user. - - page: The current page number. - - pages: The total number of pages available. """ """Number of rows to skip""" skip: int = (page - 1) * per_page + # Base query filters + filters = [Document.is_deleted.is_(False)] + + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user.id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return DocumentListResponse(items=[], total=0, page=page, pages=0) + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + # Default to personal: own documents with no workspace + filters.append(Document.user_id == user.id) + filters.append(Document.workspace_id.is_(None)) + else: + # Filter by a specific workspace ID + # Verify user has access to this workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user.id, + ).first() + if not membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this workspace.", + ) + filters.append(Document.workspace_id == workspace) + """Total Pages""" totalDocuments = ( db.query(Document) - .filter(Document.user_id == user.id, Document.is_deleted.is_(False)) + .filter(*filters) .count() ) """Total Pages""" @@ -474,7 +547,7 @@ def list_documents( """List all documents for the authenticated user in Paginated form""" docs = (( db.execute(select(Document) - .where(Document.user_id == user.id, Document.is_deleted.is_(False)) + .where(*filters) .order_by(Document.uploaded_at.desc()) .limit(per_page).offset(skip)) ) @@ -516,6 +589,113 @@ def rename_document( return _deserialize_doc(doc) +@router.get("/trash", response_model=List[DocumentResponse]) +def list_trash_documents( + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + List all soft-deleted documents (is_deleted == True) belonging to the user or their active workspaces. + """ + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user.id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + + if workspace_ids: + filters = [ + Document.is_deleted.is_(True), + or_( + Document.user_id == user.id, + Document.workspace_id.in_(workspace_ids) + ) + ] + else: + filters = [ + Document.is_deleted.is_(True), + Document.user_id == user.id + ] + + docs = ( + db.query(Document) + .filter(*filters) + .order_by(Document.deleted_at.desc()) + .all() + ) + + return [_deserialize_doc(d) for d in docs] + + +@router.post("/{document_id}/restore", response_model=DocumentResponse) +def restore_document( + document_id: str, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + Restore a soft-deleted document. + """ + doc = db.query(Document).filter( + Document.id == document_id, + Document.is_deleted.is_(True), + ).first() + + if not doc or not check_document_access(doc, user.id, db): + raise NotFoundException("Document") + + doc.is_deleted = False + doc.deleted_at = None + db.commit() + db.refresh(doc) + + return _deserialize_doc(doc) + + +@router.delete("/{document_id}/purge") +def purge_document( + document_id: str, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """ + Immediately and permanently hard-deletes the document, including its database record, + physical uploads, vector chunks, and knowledge graph files. + """ + doc = db.query(Document).filter( + Document.id == document_id + ).first() + + if not doc or not check_document_access(doc, user.id, db): + raise NotFoundException("Document") + + # 1. Delete vector chunks + try: + from app.rag.vectorstore import delete_document_chunks + delete_document_chunks(document_id=doc.id, user_id=doc.user_id) + except Exception as e: + logger.warning(f"Error cleaning vectors for {doc.id}: {e}") + + # 2. Delete knowledge graph + try: + delete_graph(user_id=doc.user_id, document_id=doc.id) + except Exception as e: + logger.warning(f"Error deleting graph for {doc.id}: {e}") + + # 3. Delete physical upload file + try: + filepath = os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename) + if os.path.exists(filepath): + os.remove(filepath) + except Exception as e: + logger.warning(f"Error deleting file for {doc.id}: {e}") + + # 4. Delete document record + db.delete(doc) + db.commit() + + return {"message": f"Document '{doc.original_name}' permanently purged successfully"} + + @router.get("/{document_id}", response_model=DocumentResponse) def get_document( document_id: str, @@ -542,11 +722,10 @@ def get_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") return _deserialize_doc(doc) @@ -579,14 +758,13 @@ def serve_pdf( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") - filepath = os.path.join(settings.UPLOAD_DIR, user.id, doc.filename) + filepath = os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename) if not os.path.exists(filepath): raise NotFoundException("File") @@ -629,11 +807,10 @@ def delete_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") doc.is_deleted = True @@ -669,14 +846,13 @@ def update_chunk_settings( HTTPException: With status code 404 if the document is not found or does not belong to the authenticated user. HTTPException: With status code 400 if the provided chunk size or overlap values are invalid (e.g., chunk size less than 100, or overlap greater than or equal to chunk size). """ - # Validate if the document exists and belongs to the user + # Validate if the document exists and belongs to the user or workspace doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") if settings_update.chunk_size is not None: @@ -706,9 +882,9 @@ def update_chunk_settings( try: task = process_document.delay( document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = task.id except Exception as e: @@ -717,9 +893,9 @@ def update_chunk_settings( background_tasks.add_task( ingest_document, document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = f"local_{uuid.uuid4().hex}" diff --git a/backend/app/routes/workspaces.py b/backend/app/routes/workspaces.py index 7ce2de63..8157389a 100644 --- a/backend/app/routes/workspaces.py +++ b/backend/app/routes/workspaces.py @@ -8,12 +8,17 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.auth import create_invite_token, get_admin_user +from app.auth import create_invite_token, get_current_user, decode_invite_token from app.config import get_settings from app.database import get_db from app.email_service import send_email -from app.models import User, WorkspaceInvitation -from app.schemas import WorkspaceInviteRequest, WorkspaceInviteResponse +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation +from app.schemas import ( + WorkspaceInviteRequest, + WorkspaceInviteResponse, + WorkspaceInviteVerifyResponse, + WorkspaceInviteAcceptResponse, +) router = APIRouter(prefix="/workspaces", tags=["Workspaces"]) settings = get_settings() @@ -23,24 +28,59 @@ @router.post("/invite", response_model=WorkspaceInviteResponse, status_code=status.HTTP_200_OK) def invite_workspace( payload: WorkspaceInviteRequest, - admin_user: User = Depends(get_admin_user), + current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """Invite a user by email to join a workspace via a secure time-bound token.""" - existing_user = db.query(User).filter(User.email == payload.email).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="A user with this email already exists.", + # 1. Lookup or create the Workspace + workspace = db.query(Workspace).filter(Workspace.name == payload.workspace_name).first() + if not workspace: + workspace = Workspace(name=payload.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Add the inviter as the admin of this workspace + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="admin", ) + db.add(membership) + db.commit() + else: + # Verify if current_user is a member of this workspace + user_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + if not user_membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to invite users to this workspace.", + ) + + # 2. Check if the invited user is already a member of this workspace + invited_user = db.query(User).filter(User.email == payload.email).first() + if invited_user: + existing_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == invited_user.id, + ).first() + if existing_membership: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The user is already a member of this workspace.", + ) - token = create_invite_token(admin_user.id, payload.email, payload.workspace_name) + # 3. Create invitation + token = create_invite_token(current_user.id, payload.email, payload.workspace_name) token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.INVITE_TOKEN_EXPIRY_HOURS) invitation = WorkspaceInvitation( email=payload.email, - inviter_id=admin_user.id, + inviter_id=current_user.id, token_hash=token_hash, workspace_name=payload.workspace_name, expires_at=expires_at, @@ -71,3 +111,114 @@ def invite_workspace( invite_link=join_link, expires_in_hours=settings.INVITE_TOKEN_EXPIRY_HOURS, ) + + +@router.get("/invite/verify", response_model=WorkspaceInviteVerifyResponse, status_code=status.HTTP_200_OK) +def verify_workspace_invite( + token: str, + db: Session = Depends(get_db), +): + """Verify an invitation token and return workspace details.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + inviter = db.query(User).filter(User.id == invitation.inviter_id).first() + inviter_email = inviter.email if inviter else "unknown" + inviter_username = inviter.username if inviter else "unknown" + + is_expired = invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc) + + return WorkspaceInviteVerifyResponse( + workspace_name=invitation.workspace_name, + inviter_email=inviter_email, + inviter_username=inviter_username, + email=invitation.email, + expires_at=invitation.expires_at, + is_expired=is_expired, + is_accepted=invitation.accepted_at is not None, + ) + + +@router.post("/invite/accept", response_model=WorkspaceInviteAcceptResponse, status_code=status.HTTP_200_OK) +def accept_workspace_invite( + token: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Accept a workspace invitation using the time-bound token.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + if invitation.accepted_at is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has already been accepted.", + ) + + if invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has expired.", + ) + + # Lookup or create the workspace + workspace = db.query(Workspace).filter( + Workspace.name == invitation.workspace_name + ).first() + if not workspace: + workspace = Workspace(name=invitation.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Check if the user is already a member + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + + if not membership: + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="member", + ) + db.add(membership) + + invitation.accepted_at = datetime.now(timezone.utc) + db.commit() + + return WorkspaceInviteAcceptResponse( + message="Invitation accepted successfully.", + workspace_name=workspace.name, + ) diff --git a/backend/app/schemas.py b/backend/app/schemas.py index c630538f..44a2e2cf 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -106,6 +106,21 @@ class WorkspaceInviteResponse(BaseModel): expires_in_hours: int +class WorkspaceInviteVerifyResponse(BaseModel): + workspace_name: str + inviter_email: str + inviter_username: str + email: str + expires_at: datetime + is_expired: bool + is_accepted: bool + + +class WorkspaceInviteAcceptResponse(BaseModel): + message: str + workspace_name: str + + class TokenResponse(BaseModel): access_token: str refresh_token: str @@ -180,6 +195,7 @@ class DocumentResponse(BaseModel): uploaded_at: datetime summary: Optional[str] = None # New field for document summary task_id: Optional[str] = None + workspace_id: Optional[str] = None extracted_urls: Optional[List[str]] = None class Config: @@ -251,6 +267,7 @@ class ChatRequest(BaseModel): document_ids: Optional[List[str]] = None session_id: Optional[str] = None top_k: int = Field(default=5, ge=1, le=20) + workspace: Optional[str] = None class SourceChunk(BaseModel): @@ -295,6 +312,7 @@ class ChunkSettings(BaseModel): class UploadUrl(BaseModel): url: str + workspace: Optional[str] = None class ShareAnswerResponse(BaseModel): id: str diff --git a/backend/app/services/cleanup.py b/backend/app/services/cleanup.py index 70e79b6a..026d6445 100644 --- a/backend/app/services/cleanup.py +++ b/backend/app/services/cleanup.py @@ -82,6 +82,13 @@ def cleanup_old_deleted_documents(): except Exception as e: logger.warning("Error deleting file for %s: %s", doc.id, e) + try: + from app.rag.graph_builder import delete_graph + + delete_graph(user_id=doc.user_id, document_id=doc.id) + except Exception as e: + logger.warning("Error deleting graph for %s: %s", doc.id, e) + db.delete(doc) if old: diff --git a/backend/tests/test_chat.py b/backend/tests/test_chat.py index 1073c9c2..3e5386a1 100644 --- a/backend/tests/test_chat.py +++ b/backend/tests/test_chat.py @@ -127,3 +127,34 @@ class MockResponse: generate_answer(question="hello?", user_id="some-user", hf_token=None) from app.config import get_settings assert called_with_token == get_settings().HF_TOKEN + + +def test_chat_ask_success_with_document_ids(client, auth_headers, ready_document, monkeypatch): + monkeypatch.setattr( + "app.routes.chat.generate_answer", + lambda question, user_id, document_id=None, document_ids=None, **kwargs: { + "answer": "Mocked answer for multiple docs", + "sources": [ + { + "text": "Mock source", + "filename": "ready.txt", + "page": 1, + "score": 0.99, + "confidence": 99.0, + } + ], + }, + ) + + response = client.post( + "/api/v1/chat/ask", + headers=auth_headers, + json={"question": "What is in the doc?", "document_ids": [ready_document.id]}, + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["answer"] == "Mocked answer for multiple docs" + assert payload["document_id"] is None + assert payload["sources"][0]["filename"] == "ready.txt" + diff --git a/backend/tests/test_documents.py b/backend/tests/test_documents.py index 9ab16166..44886606 100644 --- a/backend/tests/test_documents.py +++ b/backend/tests/test_documents.py @@ -1,3 +1,4 @@ +import pytest import types from app.models import Document @@ -182,3 +183,234 @@ def test_delete_document_soft_deletes_and_hides_document(client, auth_headers, r get_response = client.get(f"/api/v1/documents/{doc_id}", headers=auth_headers) assert get_response.status_code == 404 + + +def test_list_trash_documents(client, auth_headers, ready_document, db_session): + # Set document as soft-deleted + from datetime import datetime, timezone + ready_document.is_deleted = True + ready_document.deleted_at = datetime.now(timezone.utc) + db_session.commit() + + # Get trash + response = client.get("/api/v1/documents/trash", headers=auth_headers) + assert response.status_code == 200 + payload = response.json() + assert len(payload) == 1 + assert payload[0]["id"] == ready_document.id + assert payload[0]["original_name"] == "ready.txt" + + +def test_restore_document(client, auth_headers, ready_document, db_session): + from datetime import datetime, timezone + ready_document.is_deleted = True + ready_document.deleted_at = datetime.now(timezone.utc) + db_session.commit() + + # Verify not in active list + list_response = client.get("/api/v1/documents/", headers=auth_headers) + assert list_response.json()["total"] == 0 + + # Restore + response = client.post(f"/api/v1/documents/{ready_document.id}/restore", headers=auth_headers) + assert response.status_code == 200 + + db_session.refresh(ready_document) + assert ready_document.is_deleted is False + assert ready_document.deleted_at is None + + # Verify back in active list + list_response = client.get("/api/v1/documents/", headers=auth_headers) + assert list_response.json()["total"] == 1 + + +def test_purge_document(client, auth_headers, ready_document, db_session, monkeypatch): + from app.rag import vectorstore + import app.routes.documents + import os + + chunk_deleted = [] + graph_deleted = [] + file_deleted = [] + + monkeypatch.setattr( + vectorstore, + "delete_document_chunks", + lambda document_id, user_id: chunk_deleted.append(document_id) + ) + monkeypatch.setattr( + app.routes.documents, + "delete_graph", + lambda user_id, document_id: graph_deleted.append(document_id) + ) + monkeypatch.setattr( + os.path, + "exists", + lambda path: True + ) + monkeypatch.setattr( + os, + "remove", + lambda path: file_deleted.append(path) + ) + + doc_id = ready_document.id + + # Purge document + response = client.delete(f"/api/v1/documents/{doc_id}/purge", headers=auth_headers) + assert response.status_code == 200 + + # Verify mocks were called + assert doc_id in chunk_deleted + assert doc_id in graph_deleted + assert len(file_deleted) == 1 + + # Verify DB record is gone + refreshed = db_session.get(Document, doc_id) + assert refreshed is None + + +def test_cleanup_old_deleted_documents_purges_graph(db_session, user, monkeypatch): + from app.models import Document + from app.services.cleanup import cleanup_old_deleted_documents + from datetime import datetime, timedelta, timezone + from app.rag import vectorstore, graph_builder + import os + + # Create document soft-deleted more than 30 days ago + from app.config import get_settings + settings = get_settings() + max_age_days = settings.DOC_CLEANUP_MAX_AGE_DAYS + deleted_time = datetime.now(timezone.utc) - timedelta(days=max_age_days + 1) + + doc = Document( + id="cleanup-test-doc-id", + user_id=user.id, + filename="cleanup_test.pdf", + original_name="cleanup_test.pdf", + is_deleted=True, + deleted_at=deleted_time, + ) + db_session.add(doc) + db_session.commit() + + chunk_deleted = [] + graph_deleted = [] + file_deleted = [] + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + # Mock database session factory in cleanup + class MockDbSessionContext: + def __init__(self, session): + self.session = session + def __enter__(self): + return self.session + def __exit__(self, exc_type, exc_val, exc_tb): + if not exc_type: + self.session.commit() + monkeypatch.setattr("app.services.cleanup.get_db_session", lambda: MockDbSessionContext(db_session)) + + monkeypatch.setattr( + vectorstore, + "delete_document_chunks", + lambda document_id, user_id: chunk_deleted.append(document_id) + ) + monkeypatch.setattr( + graph_builder, + "delete_graph", + lambda user_id, document_id: graph_deleted.append(document_id) + ) + monkeypatch.setattr( + os.path, + "exists", + lambda path: True + ) + monkeypatch.setattr( + os, + "remove", + lambda path: file_deleted.append(path) + ) + + cleanup_old_deleted_documents() + + assert "cleanup-test-doc-id" in chunk_deleted + assert "cleanup-test-doc-id" in graph_deleted + assert len(file_deleted) == 1 + + # Verify db record is gone + refreshed = db_session.get(Document, "cleanup-test-doc-id") + assert refreshed is None + + +@pytest.mark.anyio +async def test_document_cleanup_job_purges_graph(db_session, user, monkeypatch): + from app.models import Document + from app.main import document_cleanup_job + from datetime import datetime, timedelta, timezone + from app.rag import vectorstore, graph_builder + import os + + # Create document inactive for more than 30 days + cutoff_time = datetime.now(timezone.utc) - timedelta(days=31) + + doc = Document( + id="inactive-test-doc-id", + user_id=user.id, + filename="inactive_test.pdf", + original_name="inactive_test.pdf", + is_deleted=False, + last_accessed_at=cutoff_time, + uploaded_at=cutoff_time, + ) + db_session.add(doc) + db_session.commit() + + chunk_deleted = [] + graph_deleted = [] + file_deleted = [] + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + + # Mock asyncio.sleep to raise exception to break infinite loop + import asyncio + async def fake_sleep(seconds): + raise asyncio.CancelledError() + monkeypatch.setattr(asyncio, "sleep", fake_sleep) + + monkeypatch.setattr( + vectorstore, + "delete_document_chunks", + lambda document_id, user_id: chunk_deleted.append(document_id) + ) + monkeypatch.setattr( + graph_builder, + "delete_graph", + lambda user_id, document_id: graph_deleted.append(document_id) + ) + monkeypatch.setattr( + os.path, + "exists", + lambda path: True + ) + monkeypatch.setattr( + os, + "remove", + lambda path: file_deleted.append(path) + ) + + try: + await document_cleanup_job() + except asyncio.CancelledError: + pass + + assert "inactive-test-doc-id" in chunk_deleted + assert "inactive-test-doc-id" in graph_deleted + assert len(file_deleted) == 1 + + # Verify db record is gone + refreshed = db_session.get(Document, "inactive-test-doc-id") + assert refreshed is None + + + + diff --git a/backend/tests/test_rag_tools.py b/backend/tests/test_rag_tools.py index 30bbc9fa..bf33a8b0 100644 --- a/backend/tests/test_rag_tools.py +++ b/backend/tests/test_rag_tools.py @@ -154,11 +154,11 @@ def test_pdf_search_tool_formats_chunks_and_graph_context(monkeypatch): retrieve_calls = [] graph_calls = [] - def fake_retrieve(query, user_id, document_id=None, top_k=None): + def fake_retrieve(query, user_id, document_id=None, workspace=None, top_k=None, **kwargs): retrieve_calls.append((query, user_id, document_id)) return chunks - def fake_get_entity_context(query, user_id, document_id=None): + def fake_get_entity_context(query, user_id, document_id=None, **kwargs): graph_calls.append((query, user_id, document_id)) return "Alpha -> acquired -> Beta" diff --git a/backend/tests/test_retriever.py b/backend/tests/test_retriever.py index 6045dde4..4448cdce 100644 --- a/backend/tests/test_retriever.py +++ b/backend/tests/test_retriever.py @@ -30,7 +30,25 @@ def test_retrieve_fans_out_transformed_queries_and_merges_duplicates(monkeypatch monkeypatch.setattr(retriever, "embed_query", lambda query: f"embedding:{query}") monkeypatch.setattr(retriever, "get_reranker", lambda: None) - def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): + # Mock SessionLocal and Document + class MockDoc: + id = "policy.pdf" + + class MockQuery: + def filter(self, *args, **kwargs): + return self + def all(self): + return [MockDoc()] + + class MockSession: + def query(self, *args, **kwargs): + return MockQuery() + def close(self): + pass + + monkeypatch.setattr("app.database.SessionLocal", lambda: MockSession()) + + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): searched_queries.append(query_embedding) if query_embedding == "embedding:taxes": return [ @@ -75,3 +93,92 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"] assert chunks[0]["score"] == 1.0 assert chunks[0]["confidence"] == 100.0 + + +def test_retrieve_excludes_soft_deleted_documents(db_session, user, monkeypatch): + from app.models import Document + from app.rag import retriever + + # Create one active document and one deleted document + active_doc = Document( + id="active-doc-id", + user_id=user.id, + filename="active.pdf", + original_name="active.pdf", + is_deleted=False, + ) + deleted_doc = Document( + id="deleted-doc-id", + user_id=user.id, + filename="deleted.pdf", + original_name="deleted.pdf", + is_deleted=True, + ) + db_session.add(active_doc) + db_session.add(deleted_doc) + db_session.commit() + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + monkeypatch.setattr(retriever, "transform_query", lambda _query: ["query"]) + monkeypatch.setattr(retriever, "embed_query", lambda query: "embedding") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + captured_doc_ids = [] + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + nonlocal captured_doc_ids + captured_doc_ids = document_ids + return [] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr(retriever.CustomBM25Retriever, "_get_relevant_documents", lambda *args, **kwargs: []) + + retriever.retrieve("test query", user_id=user.id) + + # Should only query for the active document ID + assert captured_doc_ids == ["active-doc-id"] + + +def test_retrieve_with_document_ids_list_and_rbac_checks(db_session, user, monkeypatch): + from app.models import Document + from app.rag import retriever + + # Create two documents: one owned by user (should be allowed), one owned by someone else (should be excluded) + owned_doc = Document( + id="owned-doc-id", + user_id=user.id, + filename="owned.pdf", + original_name="owned.pdf", + is_deleted=False, + ) + other_doc = Document( + id="other-doc-id", + user_id="another-user-id", + filename="other.pdf", + original_name="other.pdf", + is_deleted=False, + ) + db_session.add(owned_doc) + db_session.add(other_doc) + db_session.commit() + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + monkeypatch.setattr(retriever, "transform_query", lambda _query: ["query"]) + monkeypatch.setattr(retriever, "embed_query", lambda query: "embedding") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + captured_doc_ids = [] + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + nonlocal captured_doc_ids + captured_doc_ids = document_ids + return [] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr(retriever.CustomBM25Retriever, "_get_relevant_documents", lambda *args, **kwargs: []) + + # Retrieve with both document IDs + retriever.retrieve("test query", user_id=user.id, document_ids=["owned-doc-id", "other-doc-id"]) + + # Should only query for the owned document ID (excluding other_doc because of RBAC check) + assert captured_doc_ids == ["owned-doc-id"] + + diff --git a/backend/tests/test_workspaces.py b/backend/tests/test_workspaces.py index effdbe3b..01ae03f0 100644 --- a/backend/tests/test_workspaces.py +++ b/backend/tests/test_workspaces.py @@ -1,8 +1,20 @@ -from app.auth import create_access_token, hash_password -from app.models import User, WorkspaceInvitation +import hashlib +from datetime import datetime, timedelta, timezone +from app.auth import create_access_token, hash_password, create_invite_token +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation + + +def test_workspace_invite_creates_workspace_and_membership_for_user(client, db_session, user, monkeypatch): + sent = {} + + def fake_send_email(to, subject, body, html=None): + sent["to"] = to + sent["subject"] = subject + sent["body"] = body + + monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) -def test_workspace_invite_requires_admin(client, db_session, user): token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", @@ -10,47 +22,125 @@ def test_workspace_invite_requires_admin(client, db_session, user): json={"email": "invitee@example.com", "workspace_name": "Engineering"}, ) - assert response.status_code == 403 - assert response.json()["error"]["message"] == "Admin access required" + assert response.status_code == 200 + payload = response.json() + assert payload["email"] == "invitee@example.com" + assert payload["workspace_name"] == "Engineering" + assert "invite_link" in payload + + # Verify workspace and membership were created + workspace = db_session.query(Workspace).filter_by(name="Engineering").first() + assert workspace is not None + + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=user.id + ).first() + assert membership is not None + assert membership.role == "admin" -def test_workspace_invite_creates_invitation_and_sends_email(client, db_session, monkeypatch): - admin = User( - username="admin", - email="admin@example.com", +def test_workspace_invite_existing_member_fails(client, db_session, user, monkeypatch): + # Setup workspace and membership + workspace = Workspace(name="Marketing") + db_session.add(workspace) + db_session.commit() + + member = User( + username="member", + email="member@example.com", hashed_password=hash_password("password123"), - is_admin=True, ) - db_session.add(admin) + db_session.add(member) db_session.commit() - db_session.refresh(admin) - sent = {} - - def fake_send_email(to, subject, body, html=None): - sent["to"] = to - sent["subject"] = subject - sent["body"] = body - - monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) + membership1 = WorkspaceMembership(workspace_id=workspace.id, user_id=user.id, role="admin") + membership2 = WorkspaceMembership(workspace_id=workspace.id, user_id=member.id, role="member") + db_session.add(membership1) + db_session.add(membership2) + db_session.commit() - token = create_access_token(admin.id) + token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", headers={"Authorization": f"Bearer {token}"}, - json={"email": "invitee@example.com", "workspace_name": "Engineering"}, + json={"email": "member@example.com", "workspace_name": "Marketing"}, + ) + + assert response.status_code == 400 + assert "already a member" in response.json()["detail"] + + +def test_workspace_invite_verify(client, db_session, user): + invite_token = create_invite_token(user.id, "invitee@example.com", "Sales") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Sales", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + response = client.get( + f"/api/v1/workspaces/invite/verify?token={invite_token}" ) assert response.status_code == 200 payload = response.json() + assert payload["workspace_name"] == "Sales" assert payload["email"] == "invitee@example.com" - assert payload["workspace_name"] == "Engineering" - assert "invite_link" in payload - assert payload["invite_link"].startswith("http") - assert "token=" in payload["invite_link"] - assert sent["to"] == "invitee@example.com" - assert "Invitation to join workspace" in sent["subject"] - - invitation = db_session.query(WorkspaceInvitation).filter_by(email="invitee@example.com").first() - assert invitation is not None - assert invitation.workspace_name == "Engineering" + assert payload["inviter_email"] == user.email + assert payload["is_expired"] is False + assert payload["is_accepted"] is False + + +def test_workspace_invite_accept(client, db_session, user): + invitee = User( + username="invitee", + email="invitee@example.com", + hashed_password=hash_password("password123"), + ) + db_session.add(invitee) + db_session.commit() + + invite_token = create_invite_token(user.id, "invitee@example.com", "Support") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Support", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + # Pre-create workspace to map + workspace = Workspace(name="Support") + db_session.add(workspace) + db_session.commit() + + token = create_access_token(invitee.id) + response = client.post( + f"/api/v1/workspaces/invite/accept?token={invite_token}", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200 + assert response.json()["workspace_name"] == "Support" + + # Verify membership and invitation status + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=invitee.id + ).first() + assert membership is not None + assert membership.role == "member" + + db_session.refresh(invitation) + assert invitation.accepted_at is not None diff --git a/frontend/package-lock.json b/frontend/package-lock.json index ac9effab..41685711 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -6631,6 +6631,7 @@ "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", + "dev": true, "hasInstallScript": true, "license": "MIT", "optional": true, diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index f5a25f0a..d9ef446c 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -10,6 +10,7 @@ import Header from "@/components/layout/Header"; import DocumentSidebar from "@/components/document/DocumentSidebar"; import ChatSessionSidebar from "@/components/chat/ChatSessionSidebar"; import ChatPanel from "@/components/chat/ChatPanel"; +import { useWorkspaceStore } from "@/store/workspace-store"; function PDFViewerSkeleton() { return (
s.workspace); const [documents, setDocuments] = useState([]); const prevDocsRef = useRef>({}); const [activeDoc, setActiveDoc] = useState(null); + const [selectedDocIds, setSelectedDocIds] = useState([]); const [pdfPage, setPdfPage] = useState(1); + + useEffect(() => { + setSelectedDocIds([]); + }, [workspace]); const [pdfHighlightTarget, setPdfHighlightTarget] = useState<{ page: number; rects?: { @@ -91,6 +98,21 @@ export default function DashboardPage() { if (initialized && !user) router.replace("/login"); }, [user, initialized, router]); + // Handle pending workspace invitations after login/register + useEffect(() => { + if (initialized && user) { + try { + const pendingToken = sessionStorage.getItem("pending_invite_token"); + if (pendingToken) { + sessionStorage.removeItem("pending_invite_token"); + router.replace(`/invite?token=${encodeURIComponent(pendingToken)}`); + } + } catch (e) { + console.warn("sessionStorage not accessible", e); + } + } + }, [user, initialized, router]); + // Check if Hugging Face token configuration is present useEffect(() => { if (user) { @@ -110,7 +132,7 @@ export default function DashboardPage() { setDocumentsLoading(true); try { const data = await api.get<{ documents?: DocInfo[]; items?: DocInfo[] }>( - "/api/v1/documents/" + `/api/v1/documents/?workspace=${encodeURIComponent(workspace)}` ); setDocuments(data?.documents ?? data?.items ?? []); setConnectionError(""); @@ -124,7 +146,7 @@ export default function DashboardPage() { } finally { setDocumentsLoading(false); } - }, []); + }, [workspace]); useEffect(() => { if (!user) return; @@ -180,9 +202,12 @@ export default function DashboardPage() { onSelectDoc={(doc) => { setActiveDoc(doc); setPdfPage(1); + setSelectedDocIds([]); }} onDocumentsChange={loadDocuments} onDocumentRenamed={handleDocumentRenamed} + selectedDocIds={selectedDocIds} + onSelectDocsChange={setSelectedDocIds} /> ); @@ -220,6 +245,7 @@ export default function DashboardPage() {
{ setPdfPage(target.page); setPdfHighlightTarget({ page: target.page, rects: target.highlightRects }); diff --git a/frontend/src/app/invite/page.tsx b/frontend/src/app/invite/page.tsx new file mode 100644 index 00000000..0d7a5ce1 --- /dev/null +++ b/frontend/src/app/invite/page.tsx @@ -0,0 +1,215 @@ +"use client"; + +import { useEffect, useState, Suspense } from "react"; +import { useRouter, useSearchParams } from "next/navigation"; +import { api } from "@/lib/api"; +import { useAuth } from "@/lib/auth"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card"; +import { Brain, CheckCircle2, AlertTriangle, Loader2, ArrowRight } from "lucide-react"; +import { toast } from "sonner"; +import { useWorkspaceStore } from "@/store/workspace-store"; + +interface InviteInfo { + workspace_name: string; + inviter_email: string; + inviter_username: string; + email: string; + expires_at: string; + is_expired: boolean; + is_accepted: boolean; +} + +function InviteContent() { + const router = useRouter(); + const searchParams = useSearchParams(); + const { user, initialized } = useAuth(); + const setWorkspace = useWorkspaceStore((s) => s.setWorkspace); + + const token = searchParams.get("token"); + + const [loading, setLoading] = useState(true); + const [accepting, setAccepting] = useState(false); + const [error, setError] = useState(""); + const [inviteInfo, setInviteInfo] = useState(null); + + useEffect(() => { + if (!token) { + setError("Invitation token is missing. Please check the link in your email."); + setLoading(false); + return; + } + + // Save token to sessionStorage in case they need to log in/register + try { + sessionStorage.setItem("pending_invite_token", token); + } catch (e) { + console.warn("sessionStorage is not available", e); + } + + api + .get(`/api/v1/workspaces/invite/verify?token=${encodeURIComponent(token)}`) + .then((data) => { + setInviteInfo(data); + if (data.is_expired) { + setError(`This invitation has expired (expired at ${new Date(data.expires_at).toLocaleString()}).`); + } else if (data.is_accepted) { + setError("This invitation has already been accepted."); + } + }) + .catch((err) => { + setError(err instanceof Error ? err.message : "Failed to verify invitation token."); + }) + .finally(() => { + setLoading(false); + }); + }, [token]); + + const handleAccept = async () => { + if (!token) return; + setAccepting(true); + + try { + await api.post(`/api/v1/workspaces/invite/accept?token=${encodeURIComponent(token)}`); + toast.success(`🎉 Welcome to the '${inviteInfo?.workspace_name}' workspace!`); + + // Clean up sessionStorage + try { + sessionStorage.removeItem("pending_invite_token"); + } catch (e) { + // ignore + } + + // Switch to company workspace in store so they see the documents immediately + setWorkspace("company"); + + router.push("/dashboard"); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to accept invitation."); + setAccepting(false); + } + }; + + if (loading || !initialized) { + return ( +
+ +

Verifying invitation details...

+
+ ); + } + + if (error) { + return ( + + +
+
+ +
+
+ Invitation Error + {error} +
+ + + +
+ ); + } + + const isLoggedIn = !!user; + + return ( + + +
+
+ +
+
+ Workspace Invitation + + You've been invited to join a collaborative workspace + +
+ + +
+

Workspace

+

{inviteInfo?.workspace_name}

+

+ Invited by {inviteInfo?.inviter_username} ({inviteInfo?.inviter_email}) +

+
+ + {isLoggedIn ? ( +
+

+ You are logged in as {user.username} ({user.email}). +

+ +
+ ) : ( +
+
+ Please log in or register an account to accept this invitation. +
+
+ + +
+
+ )} +
+
+ ); +} + +export default function InvitePage() { + return ( +
+ {/* Aesthetic blur gradients */} +
+
+ + + +

Loading...

+
+ }> + + +
+ ); +} diff --git a/frontend/src/components/chat/ChatPanel.tsx b/frontend/src/components/chat/ChatPanel.tsx index dde5d025..7efc5fa0 100644 --- a/frontend/src/components/chat/ChatPanel.tsx +++ b/frontend/src/components/chat/ChatPanel.tsx @@ -13,6 +13,7 @@ import MessageBubble from "./MessageBubble"; import SourceCard from "./SourceCard"; import { Send, Loader2, Trash2, MessageSquare, Download, Mic, MicOff, HelpCircle } from "lucide-react"; import { cn } from "@/lib/utils"; +import { useWorkspaceStore } from "@/store/workspace-store"; interface ISpeechRecognitionEvent { resultIndex: number; @@ -55,11 +56,13 @@ interface CitationTarget { interface Props { activeDoc: DocInfo | null; + selectedDocIds?: string[]; onCitationClick: (target: CitationTarget) => void; } -export default function ChatPanel({ activeDoc, onCitationClick }: Props) { +export default function ChatPanel({ activeDoc, selectedDocIds = [], onCitationClick }: Props) { const { t, i18n } = useTranslation(); + const workspace = useWorkspaceStore((s) => s.workspace); const messages = useChatStore((state) => state.messages); const input = useChatStore((state) => state.input); const streaming = useChatStore((state) => state.streaming); @@ -198,7 +201,13 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { const wsDone = new Promise((resolve, reject) => { ws.onopen = () => { // Send initial payload - ws.send(JSON.stringify({ question, document_id: activeDoc?.id || null, session_id: activeSessionId })); + ws.send(JSON.stringify({ + question, + document_id: selectedDocIds.length > 0 ? null : (activeDoc?.id || null), + document_ids: selectedDocIds.length > 0 ? selectedDocIds : null, + session_id: activeSessionId, + workspace + })); }; // If WS doesn't open within 800ms, treat as failure and fallback @@ -273,8 +282,10 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { try { const stream = api.streamPost("/api/v1/chat/ask/stream", { question, - document_id: activeDoc?.id || null, + document_id: selectedDocIds.length > 0 ? null : (activeDoc?.id || null), + document_ids: selectedDocIds.length > 0 ? selectedDocIds : null, session_id: activeSessionId, + workspace, }); for await (const event of stream) { @@ -544,10 +555,16 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) {

- {activeDoc ? t("chat.askAboutDocument") : t("chat.selectDocument")} + {selectedDocIds.length > 0 + ? `Chatting with ${selectedDocIds.length} selected ${selectedDocIds.length === 1 ? "file" : "files"}` + : activeDoc + ? t("chat.askAboutDocument") + : t("chat.selectDocument")}

- {activeDoc + {selectedDocIds.length > 0 + ? `Ask a question to query and synthesize across the selected ${selectedDocIds.length === 1 ? "file" : "files"}.` + : activeDoc ? t("chat.readyPrompt", { name: activeDoc.original_name }) : t("chat.uploadPrompt")}

@@ -624,7 +641,9 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { onChange={(e) => setInput(e.target.value)} onKeyDown={handleKeyDown} placeholder={ - activeDoc + selectedDocIds.length > 0 + ? `Ask about the ${selectedDocIds.length} selected ${selectedDocIds.length === 1 ? "file" : "files"}...` + : activeDoc ? t("chat.askPlaceholder", { name: activeDoc.original_name }) : t("chat.selectPlaceholder") } @@ -694,7 +713,7 @@ export default function ChatPanel({ activeDoc, onCitationClick }: Props) { )} - {messages.length > 0 && ( + {messages.length > 0 && activeDoc && selectedDocIds.length === 0 && ( <> {/* Export dropdown */}
diff --git a/frontend/src/components/document/DocumentSidebar.tsx b/frontend/src/components/document/DocumentSidebar.tsx index a4c158e5..0f42a9e7 100644 --- a/frontend/src/components/document/DocumentSidebar.tsx +++ b/frontend/src/components/document/DocumentSidebar.tsx @@ -16,7 +16,9 @@ import { import { useDropzone } from "react-dropzone"; import { Settings } from "lucide-react"; import DocumentSettings from "./DocumentSettings"; +import TrashModal from "./TrashModal"; import { toast } from "sonner"; +import { useWorkspaceStore } from "@/store/workspace-store"; interface Props { documents: DocInfo[]; @@ -25,6 +27,8 @@ interface Props { onSelectDoc: (doc: DocInfo) => void; onDocumentsChange: () => void; onDocumentRenamed: (doc: DocInfo) => void; + selectedDocIds?: string[]; + onSelectDocsChange?: (ids: string[]) => void; } function DocumentListSkeleton() { @@ -56,8 +60,11 @@ export default function DocumentSidebar({ onSelectDoc, onDocumentsChange, onDocumentRenamed, + selectedDocIds = [], + onSelectDocsChange, }: Props) { const { t } = useTranslation(); + const workspace = useWorkspaceStore((s) => s.workspace); const [uploading, setUploading] = useState(false); const [uploadProgress, setUploadProgress] = useState(0); const [uploadError, setUploadError] = useState(""); @@ -70,6 +77,7 @@ export default function DocumentSidebar({ const [driveLoading, setDriveLoading] = useState(true); const [driveConnecting, setDriveConnecting] = useState(false); const [driveError, setDriveError] = useState(""); + const [trashOpen, setTrashOpen] = useState(false); useEffect(() => { let cancelled = false; @@ -105,6 +113,7 @@ export default function DocumentSidebar({ const file = acceptedFiles[i]; const formData = new FormData(); formData.append("file", file); + formData.append("workspace", workspace); toast.info(`⏳ Uploading '${file.name}'...`); await api.postForm("/api/v1/documents/upload", formData); @@ -122,7 +131,7 @@ export default function DocumentSidebar({ } })(); }, - [onDocumentsChange, t] + [onDocumentsChange, t, workspace] ); const { getRootProps, getInputProps, isDragActive } = useDropzone({ @@ -334,10 +343,29 @@ export default function DocumentSidebar({ {/* ── Documents List ──────────────────────────── */}
-

- {loading - ? t("documents.documentsTitle", { count: "..." }) - : t("documents.documentsTitle", { count: documents.length })} +

+
+ + {loading + ? t("documents.documentsTitle", { count: "..." }) + : t("documents.documentsTitle", { count: documents.length })} + + {selectedDocIds && selectedDocIds.length > 0 && ( + + {selectedDocIds.length} selected + + )} +
+

@@ -369,10 +397,28 @@ export default function DocumentSidebar({ className={`w-full text-left p-2.5 rounded-lg transition-all duration-200 group ${activeDoc?.id === doc.id ? "bg-primary/15 border border-primary/30" + : selectedDocIds.includes(doc.id) + ? "bg-sidebar-accent/70 border border-sidebar-border/50" : "hover:bg-sidebar-accent border border-transparent"} ${doc.status !== "ready" ? "opacity-60 cursor-default" : "cursor-pointer focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2"}`} >
+ {doc.status === "ready" && ( + e.stopPropagation()} + onChange={(e) => { + e.stopPropagation(); + if (e.target.checked) { + onSelectDocsChange?.([...selectedDocIds, doc.id]); + } else { + onSelectDocsChange?.(selectedDocIds.filter((id) => id !== doc.id)); + } + }} + className="mt-1 h-3.5 w-3.5 rounded border-sidebar-border bg-transparent text-primary focus:ring-primary focus:ring-offset-sidebar cursor-pointer shrink-0" + /> + )} {statusIcon(doc.status)}
{isEditing ? ( @@ -479,6 +525,13 @@ export default function DocumentSidebar({ }} /> )} + {trashOpen && ( + + )}
); } diff --git a/frontend/src/components/document/TrashModal.tsx b/frontend/src/components/document/TrashModal.tsx new file mode 100644 index 00000000..4300c399 --- /dev/null +++ b/frontend/src/components/document/TrashModal.tsx @@ -0,0 +1,188 @@ +"use client"; + +import { useEffect, useState } from "react"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { Trash2, RotateCcw, FileText, Loader2, AlertCircle } from "lucide-react"; +import { api } from "@/lib/api"; +import { toast } from "sonner"; +import type { DocInfo } from "@/app/dashboard/page"; + +interface Props { + open: boolean; + onOpenChange: (open: boolean) => void; + onDocumentsChange: () => void; +} + +export default function TrashModal({ open, onOpenChange, onDocumentsChange }: Props) { + const [trashDocs, setTrashDocs] = useState([]); + const [loading, setLoading] = useState(false); + const [actionId, setActionId] = useState(null); + const [actionType, setActionType] = useState<"restore" | "purge" | null>(null); + + const fetchTrash = async () => { + setLoading(true); + try { + const data = await api.get("/api/v1/documents/trash"); + setTrashDocs(data); + } catch (err) { + console.error("Failed to load trash documents:", err); + toast.error("Failed to load Recycle Bin items"); + } finally { + setLoading(false); + } + }; + + useEffect(() => { + if (open) { + void fetchTrash(); + } + }, [open]); + + const handleRestore = async (doc: DocInfo) => { + setActionId(doc.id); + setActionType("restore"); + try { + await api.post(`/api/v1/documents/${doc.id}/restore`); + toast.success(`🎉 Restored '${doc.original_name}' successfully!`); + // Update local state and trigger parent refresh + setTrashDocs((prev) => prev.filter((d) => d.id !== doc.id)); + onDocumentsChange(); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to restore document"); + } finally { + setActionId(null); + setActionType(null); + } + }; + + const handlePurge = async (doc: DocInfo) => { + if (!confirm(`⚠️ Are you sure you want to permanently delete '${doc.original_name}'? This action is irreversible and will purge all files, vector chunks, and graph data immediately.`)) { + return; + } + + setActionId(doc.id); + setActionType("purge"); + try { + await api.delete(`/api/v1/documents/${doc.id}/purge`); + toast.success(`🗑️ Permanently deleted '${doc.original_name}'`); + setTrashDocs((prev) => prev.filter((d) => d.id !== doc.id)); + onDocumentsChange(); + } catch (err) { + toast.error(err instanceof Error ? err.message : "Failed to delete document"); + } finally { + setActionId(null); + setActionType(null); + } + }; + + const formatSize = (bytes: number) => { + if (bytes < 1024) return `${bytes} B`; + if (bytes < 1048576) return `${(bytes / 1024).toFixed(1)} KB`; + return `${(bytes / 1048576).toFixed(1)} MB`; + }; + + return ( + + + + + + Recycle Bin + + + Items in the Recycle Bin will be permanently purged after 30 days. You can restore them or delete them permanently now. + + + +
+ + {loading && trashDocs.length === 0 ? ( +
+ +

Loading trashed files...

+
+ ) : trashDocs.length === 0 ? ( +
+
+ +
+

Recycle Bin is empty

+

+ Soft-deleted documents will appear here. +

+
+ ) : ( +
+ {trashDocs.map((doc) => ( +
+
+ +
+

+ {doc.original_name} +

+
+ {formatSize(doc.file_size)} + + Deleted: {new Date(doc.uploaded_at).toLocaleDateString()} +
+
+
+ +
+ + +
+
+ ))} +
+ )} +
+
+ +
+ +
+
+
+ ); +}