diff --git a/backend/app/database.py b/backend/app/database.py index 8b644760..f74c408b 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -166,6 +166,7 @@ def _migrate_schema(): ("documents", "drive_file_id", "ALTER TABLE documents ADD COLUMN drive_file_id VARCHAR(255)"), ("documents", "drive_folder_id", "ALTER TABLE documents ADD COLUMN drive_folder_id VARCHAR(255)"), ("documents", "drive_synced_at", "ALTER TABLE documents ADD COLUMN drive_synced_at TIMESTAMP"), + ("documents", "workspace_id", "ALTER TABLE documents ADD COLUMN workspace_id VARCHAR(36)"), ] for table, column, ddl in docs_migrations: if column not in existing_docs_columns: diff --git a/backend/app/models.py b/backend/app/models.py index a8611fe6..a812213b 100644 --- a/backend/app/models.py +++ b/backend/app/models.py @@ -163,6 +163,11 @@ class User(Base): back_populates="user", cascade="all, delete-orphan", ) + workspace_memberships = relationship( + "WorkspaceMembership", + back_populates="user", + cascade="all, delete-orphan", + ) class ApiKey(Base): @@ -185,6 +190,32 @@ class ApiKey(Base): user = relationship("User", back_populates="api_keys") +class Workspace(Base): + __tablename__ = "workspaces" + + id = Column(String(36), primary_key=True, default=generate_uuid) + name = Column(String(100), nullable=False) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + memberships = relationship("WorkspaceMembership", back_populates="workspace", cascade="all, delete-orphan") + documents = relationship("Document", back_populates="workspace") + + +class WorkspaceMembership(Base): + __tablename__ = "workspace_memberships" + + id = Column(String(36), primary_key=True, default=generate_uuid) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=False, index=True) + user_id = Column(GUID, ForeignKey("users.id"), nullable=False, index=True) + role = Column(String(20), default="member", nullable=False) # "admin" | "member" + joined_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + + # Relationships + workspace = relationship("Workspace", back_populates="memberships") + user = relationship("User", back_populates="workspace_memberships") + + class WorkspaceInvitation(Base): __tablename__ = "workspace_invitations" @@ -254,6 +285,7 @@ class Document(Base): drive_synced_at = Column(DateTime, nullable=True) is_deleted = Column(Boolean, default=False, nullable=False, index=True) deleted_at = Column(DateTime, nullable=True) + workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=True, index=True) processing_progress = Column(Integer, default=0) processing_stage = Column(String(20), default="queued") retry_count = Column(Integer, default=0) @@ -264,6 +296,7 @@ class Document(Base): # Relationships owner = relationship("User", back_populates="documents") + workspace = relationship("Workspace", back_populates="documents") messages = relationship( "ChatMessage", back_populates="document", diff --git a/backend/app/rag/agent.py b/backend/app/rag/agent.py index 4917cfe5..d8bc0990 100644 --- a/backend/app/rag/agent.py +++ b/backend/app/rag/agent.py @@ -56,11 +56,12 @@ def get_agent_executor( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ): """Initialize the LangChain ReAct agent executor.""" # Initialize tools - pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k) + pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, workspace=workspace, top_k=top_k) tools = [pdf_tool, MathTool(), WebSearchTool()] # Initialize LLM @@ -122,6 +123,7 @@ def generate_answer( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Dict[str, Any]: """ Agentic generation: retrieve via tools → reason → generate answer. @@ -146,7 +148,7 @@ def generate_answer( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) result = executor.invoke({"input": question, "chat_history": formatted_history}) raw_answer = result.get("output", "") @@ -194,6 +196,7 @@ def generate_answer_stream( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[List[Dict[str, str]]] = None, + workspace: Optional[str] = None, ) -> Generator[str, None, None]: """ Streaming Agentic pipeline. @@ -219,7 +222,7 @@ def generate_answer_stream( # ── Run Agent ──────────────────────────────────── try: - executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history) + executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace) sources_sent = False diff --git a/backend/app/rag/bm25.py b/backend/app/rag/bm25.py index 84d89c88..9ca8d196 100644 --- a/backend/app/rag/bm25.py +++ b/backend/app/rag/bm25.py @@ -108,6 +108,7 @@ def query_bm25( query: str, user_id: str, document_id: Optional[str] = None, + document_ids: Optional[List[str]] = None, top_k: int = 10, ) -> List[Dict[str, Any]]: """ @@ -127,6 +128,11 @@ def query_bm25( all_results = [] for path in glob.glob(os.path.join(user_dir, "*.pkl")): + # Filter by document_ids if provided + if document_ids is not None: + doc_id = os.path.basename(path).rsplit(".", 1)[0] + if doc_id not in document_ids: + continue results = _query_single_index(path, tokenized_query, top_k) all_results.extend(results) diff --git a/backend/app/rag/retriever.py b/backend/app/rag/retriever.py index e542c17f..65d96700 100644 --- a/backend/app/rag/retriever.py +++ b/backend/app/rag/retriever.py @@ -41,6 +41,7 @@ def invoke(self, query): class CustomVectorRetriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -51,6 +52,7 @@ def _get_relevant_documents( query_embedding=query_vector, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] @@ -59,6 +61,7 @@ def _get_relevant_documents( class CustomBM25Retriever(BaseRetriever): user_id: str = Field(description="User ID") document_id: Optional[str] = Field(default=None, description="Document ID") + document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs") top_k: int = Field(default=10, description="Top K results") def _get_relevant_documents( @@ -69,11 +72,13 @@ def _get_relevant_documents( query=query, user_id=self.user_id, document_id=self.document_id, + document_ids=self.document_ids, top_k=self.top_k, ) return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates] + def transform_query(query: str) -> List[str]: """Rewrite a user question into multiple retrieval-friendly search queries.""" original_query = query.strip() @@ -219,6 +224,7 @@ def retrieve( query: str, user_id: str, document_id: Optional[str] = None, + workspace: Optional[str] = None, top_k: Optional[int] = None, ) -> List[Dict[str, Any]]: """ @@ -228,17 +234,80 @@ def retrieve( Returns chunks with confidence scores. """ + from app.database import SessionLocal + from app.models import Document, WorkspaceMembership + + # Fetch active document IDs + db = SessionLocal() + try: + if document_id: + # Check if specific document is active (not deleted) + doc = db.query(Document).filter( + Document.id == document_id, + Document.is_deleted.is_(False), + ).first() + if not doc: + return [] + + # Verify user has access to the document + has_access = False + if str(doc.user_id) == str(user_id): + has_access = True + elif doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + has_access = True + + if not has_access: + return [] + active_doc_ids = [str(doc.id)] + else: + # Filter documents by workspace + filters = [Document.is_deleted.is_(False)] + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return [] + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + filters.append(Document.user_id == user_id) + filters.append(Document.workspace_id.is_(None)) + else: + # Specific workspace ID + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user_id, + ).first() + if not membership: + return [] + filters.append(Document.workspace_id == workspace) + + docs = db.query(Document).filter(*filters).all() + if not docs: + return [] + active_doc_ids = [str(doc.id) for doc in docs] + finally: + db.close() + # ── Stage 1: Hybrid Search with Query Transformation ───────────── effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL vector_retriever = CustomVectorRetriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) bm25_retriever = CustomBM25Retriever( user_id=user_id, document_id=document_id, + document_ids=active_doc_ids if not document_id else None, top_k=effective_top_k, ) diff --git a/backend/app/rag/tools.py b/backend/app/rag/tools.py index 03813756..c03c56ab 100644 --- a/backend/app/rag/tools.py +++ b/backend/app/rag/tools.py @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool): user_id: str document_id: Optional[str] = None + workspace: Optional[str] = None top_k: Optional[int] = None # We'll store sources here to retrieve them after agent execution last_sources: List[Dict[str, Any]] = [] @@ -167,6 +168,7 @@ def _run(self, query: str) -> str: query=query, user_id=self.user_id, document_id=self.document_id, + workspace=self.workspace, top_k=self.top_k, ) diff --git a/backend/app/routes/chat.py b/backend/app/routes/chat.py index 00f14cc3..03e9bff1 100644 --- a/backend/app/routes/chat.py +++ b/backend/app/routes/chat.py @@ -444,6 +444,7 @@ def generate_answer( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, + workspace: Optional[str] = None, ): from app.rag.agent import generate_answer as _generate_answer @@ -454,6 +455,7 @@ def generate_answer( hf_token=hf_token, top_k=top_k, chat_history=chat_history, + workspace=workspace, ) @@ -464,6 +466,7 @@ def generate_answer_stream( hf_token: Optional[str] = None, top_k: Optional[int] = None, chat_history: Optional[list] = None, + workspace: Optional[str] = None, ): from app.rag.agent import generate_answer_stream as _generate_answer_stream @@ -474,6 +477,7 @@ def generate_answer_stream( hf_token=hf_token, top_k=top_k, chat_history=chat_history, + workspace=workspace, ) @@ -575,6 +579,7 @@ def ask_question( hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ) # Store result in cache for future identical questions @@ -716,6 +721,7 @@ def event_stream(): hf_token=user.hf_token, top_k=payload.top_k, chat_history=chat_history, + workspace=payload.workspace, ): yield chunk diff --git a/backend/app/routes/documents.py b/backend/app/routes/documents.py index 5aa5c73f..f4f9577c 100644 --- a/backend/app/routes/documents.py +++ b/backend/app/routes/documents.py @@ -16,7 +16,7 @@ import ipaddress import tempfile from urllib.parse import urlparse -from fastapi import APIRouter, Depends, UploadFile, File, status, Query, BackgroundTasks +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status, Query, BackgroundTasks, Form from fastapi.responses import FileResponse from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ AppException, ForbiddenException, ) -from app.models import User, Document +from app.models import User, Document, Workspace, WorkspaceMembership from app.schemas import ( DocumentResponse, DocumentListResponse, @@ -60,6 +60,50 @@ router = APIRouter(prefix="/documents", tags=["Documents"]) +def get_target_workspace_id(workspace_param: Optional[str], user_id: str, db: Session) -> Optional[str]: + if not workspace_param or workspace_param == "personal": + return None + + if workspace_param == "company": + # Check if the user is a member of any workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user_id + ).first() + if membership: + return membership.workspace_id + + # If not, create a default "Company" workspace for them + workspace = Workspace(name="Company") + db.add(workspace) + db.commit() + db.refresh(workspace) + + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=user_id, + role="admin", + ) + db.add(membership) + db.commit() + return workspace.id + + # If a specific workspace ID was passed, just return it + return workspace_param + + +def check_document_access(doc: Document, user_id: str, db: Session) -> bool: + if str(doc.user_id) == str(user_id): + return True + if doc.workspace_id: + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == doc.workspace_id, + WorkspaceMembership.user_id == user_id, + ).first() + if membership: + return True + return False + + ALLOWED_MIME_TYPES = settings.ALLOWED_MIME_TYPES def _deserialize_doc(doc: Document) -> DocumentResponse: @@ -190,6 +234,7 @@ async def _crawl(): @router.post("/upload", response_model=DocumentResponse, status_code=status.HTTP_202_ACCEPTED) async def upload_document( file: UploadFile = File(...), + workspace: Optional[str] = Form(None), background_tasks: BackgroundTasks = None, user: User = Depends(get_current_user), db: Session = Depends(get_db), @@ -205,10 +250,11 @@ async def upload_document( Args: file: The uploaded file, provided as a multipart/form-data field in the request. + workspace: Optional workspace context ('personal', 'company', or specific UUID). background_tasks: FastAPI BackgroundTasks instance for in-process fallback execution. user: The currently authenticated user, injected by the `get_current_user` dependency. db: Database session, injected by the `get_db` dependency. - + Returns: DocumentResponse: The created document record, validated against the response model (includes id, filename, original_name, file_size, status, etc.). @@ -245,6 +291,11 @@ async def upload_document( file_size = Path(filepath).stat().st_size + # Resolve target workspace + if not isinstance(workspace, str): + workspace = None + target_workspace_id = get_target_workspace_id(workspace, user.id, db) + # ── Create database record ─────────────────────── document = Document( user_id=user.id, @@ -252,6 +303,7 @@ async def upload_document( original_name=file.filename, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -356,6 +408,9 @@ async def upload_document_url( url_path = parsed.path.rstrip("/") original_name = f"{parsed.netloc}{url_path or ''}.txt" + # Resolve target workspace + target_workspace_id = get_target_workspace_id(payload.workspace, user.id, db) + # ── Create database record ───────────────────────────── document = Document( user_id=user.id, @@ -363,6 +418,7 @@ async def upload_document_url( original_name=original_name, file_size=file_size, status="pending", + workspace_id=target_workspace_id, ) db.add(document) db.commit() @@ -422,11 +478,10 @@ def get_document_status( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") return DocumentStatusResponse.model_validate(doc) @@ -436,36 +491,53 @@ def get_document_status( def list_documents( page: int = Query(1, ge=1), per_page: int = Query(20, ge=1), + workspace: Optional[str] = Query(None), user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """ - List all documents for the authenticated user with pagination. + List all documents for the authenticated user with pagination, filtered by workspace. - Returns a paginated list of documents belonging to the current user, + Returns a paginated list of documents belonging to the current user or their workspaces, ordered by upload date (newest first). - - Args: - page: The page number to retrieve (1: indexed). Defaults to 1. - per_page: The number of documents to return per page. Defaults to 20. - user: The currently authenticated user, injected by the `get_current_user` dependency. - db: Database session, injected by the `get_db` dependency. - - Returns: - DocumentListResponse: A response model containing: - - items: A list of DocumentResponse objects for the current page. - - total: The total number of documents for the user. - - page: The current page number. - - pages: The total number of pages available. """ """Number of rows to skip""" skip: int = (page - 1) * per_page + # Base query filters + filters = [Document.is_deleted.is_(False)] + + if workspace == "company": + memberships = db.query(WorkspaceMembership).filter( + WorkspaceMembership.user_id == user.id + ).all() + workspace_ids = [m.workspace_id for m in memberships] + if not workspace_ids: + return DocumentListResponse(items=[], total=0, page=page, pages=0) + filters.append(Document.workspace_id.in_(workspace_ids)) + elif workspace == "personal" or not workspace: + # Default to personal: own documents with no workspace + filters.append(Document.user_id == user.id) + filters.append(Document.workspace_id.is_(None)) + else: + # Filter by a specific workspace ID + # Verify user has access to this workspace + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace, + WorkspaceMembership.user_id == user.id, + ).first() + if not membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have access to this workspace.", + ) + filters.append(Document.workspace_id == workspace) + """Total Pages""" totalDocuments = ( db.query(Document) - .filter(Document.user_id == user.id, Document.is_deleted.is_(False)) + .filter(*filters) .count() ) """Total Pages""" @@ -474,7 +546,7 @@ def list_documents( """List all documents for the authenticated user in Paginated form""" docs = (( db.execute(select(Document) - .where(Document.user_id == user.id, Document.is_deleted.is_(False)) + .where(*filters) .order_by(Document.uploaded_at.desc()) .limit(per_page).offset(skip)) ) @@ -542,11 +614,10 @@ def get_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") return _deserialize_doc(doc) @@ -579,14 +650,13 @@ def serve_pdf( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") - filepath = os.path.join(settings.UPLOAD_DIR, user.id, doc.filename) + filepath = os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename) if not os.path.exists(filepath): raise NotFoundException("File") @@ -629,11 +699,10 @@ def delete_document( """ doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") doc.is_deleted = True @@ -669,14 +738,13 @@ def update_chunk_settings( HTTPException: With status code 404 if the document is not found or does not belong to the authenticated user. HTTPException: With status code 400 if the provided chunk size or overlap values are invalid (e.g., chunk size less than 100, or overlap greater than or equal to chunk size). """ - # Validate if the document exists and belongs to the user + # Validate if the document exists and belongs to the user or workspace doc = db.query(Document).filter( Document.id == document_id, - Document.user_id == user.id, Document.is_deleted.is_(False), ).first() - if not doc: + if not doc or not check_document_access(doc, user.id, db): raise NotFoundException("Document") if settings_update.chunk_size is not None: @@ -706,9 +774,9 @@ def update_chunk_settings( try: task = process_document.delay( document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = task.id except Exception as e: @@ -717,9 +785,9 @@ def update_chunk_settings( background_tasks.add_task( ingest_document, document_id=doc.id, - filepath=os.path.join(settings.UPLOAD_DIR, user.id, doc.filename), + filepath=os.path.join(settings.UPLOAD_DIR, str(doc.user_id), doc.filename), original_name=doc.original_name, - user_id=user.id, + user_id=str(doc.user_id), ) task_id = f"local_{uuid.uuid4().hex}" diff --git a/backend/app/routes/workspaces.py b/backend/app/routes/workspaces.py index 7ce2de63..8157389a 100644 --- a/backend/app/routes/workspaces.py +++ b/backend/app/routes/workspaces.py @@ -8,12 +8,17 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from app.auth import create_invite_token, get_admin_user +from app.auth import create_invite_token, get_current_user, decode_invite_token from app.config import get_settings from app.database import get_db from app.email_service import send_email -from app.models import User, WorkspaceInvitation -from app.schemas import WorkspaceInviteRequest, WorkspaceInviteResponse +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation +from app.schemas import ( + WorkspaceInviteRequest, + WorkspaceInviteResponse, + WorkspaceInviteVerifyResponse, + WorkspaceInviteAcceptResponse, +) router = APIRouter(prefix="/workspaces", tags=["Workspaces"]) settings = get_settings() @@ -23,24 +28,59 @@ @router.post("/invite", response_model=WorkspaceInviteResponse, status_code=status.HTTP_200_OK) def invite_workspace( payload: WorkspaceInviteRequest, - admin_user: User = Depends(get_admin_user), + current_user: User = Depends(get_current_user), db: Session = Depends(get_db), ): """Invite a user by email to join a workspace via a secure time-bound token.""" - existing_user = db.query(User).filter(User.email == payload.email).first() - if existing_user: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="A user with this email already exists.", + # 1. Lookup or create the Workspace + workspace = db.query(Workspace).filter(Workspace.name == payload.workspace_name).first() + if not workspace: + workspace = Workspace(name=payload.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Add the inviter as the admin of this workspace + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="admin", ) + db.add(membership) + db.commit() + else: + # Verify if current_user is a member of this workspace + user_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + if not user_membership: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You do not have permission to invite users to this workspace.", + ) + + # 2. Check if the invited user is already a member of this workspace + invited_user = db.query(User).filter(User.email == payload.email).first() + if invited_user: + existing_membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == invited_user.id, + ).first() + if existing_membership: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The user is already a member of this workspace.", + ) - token = create_invite_token(admin_user.id, payload.email, payload.workspace_name) + # 3. Create invitation + token = create_invite_token(current_user.id, payload.email, payload.workspace_name) token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() expires_at = datetime.now(timezone.utc) + timedelta(hours=settings.INVITE_TOKEN_EXPIRY_HOURS) invitation = WorkspaceInvitation( email=payload.email, - inviter_id=admin_user.id, + inviter_id=current_user.id, token_hash=token_hash, workspace_name=payload.workspace_name, expires_at=expires_at, @@ -71,3 +111,114 @@ def invite_workspace( invite_link=join_link, expires_in_hours=settings.INVITE_TOKEN_EXPIRY_HOURS, ) + + +@router.get("/invite/verify", response_model=WorkspaceInviteVerifyResponse, status_code=status.HTTP_200_OK) +def verify_workspace_invite( + token: str, + db: Session = Depends(get_db), +): + """Verify an invitation token and return workspace details.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + inviter = db.query(User).filter(User.id == invitation.inviter_id).first() + inviter_email = inviter.email if inviter else "unknown" + inviter_username = inviter.username if inviter else "unknown" + + is_expired = invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc) + + return WorkspaceInviteVerifyResponse( + workspace_name=invitation.workspace_name, + inviter_email=inviter_email, + inviter_username=inviter_username, + email=invitation.email, + expires_at=invitation.expires_at, + is_expired=is_expired, + is_accepted=invitation.accepted_at is not None, + ) + + +@router.post("/invite/accept", response_model=WorkspaceInviteAcceptResponse, status_code=status.HTTP_200_OK) +def accept_workspace_invite( + token: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Accept a workspace invitation using the time-bound token.""" + payload = decode_invite_token(token) + if not payload: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid or expired invitation token.", + ) + + token_hash = hashlib.sha256(token.encode("utf-8")).hexdigest() + invitation = db.query(WorkspaceInvitation).filter( + WorkspaceInvitation.token_hash == token_hash + ).first() + + if not invitation: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Invitation not found.", + ) + + if invitation.accepted_at is not None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has already been accepted.", + ) + + if invitation.expires_at.replace(tzinfo=timezone.utc) < datetime.now(timezone.utc): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="This invitation has expired.", + ) + + # Lookup or create the workspace + workspace = db.query(Workspace).filter( + Workspace.name == invitation.workspace_name + ).first() + if not workspace: + workspace = Workspace(name=invitation.workspace_name) + db.add(workspace) + db.commit() + db.refresh(workspace) + + # Check if the user is already a member + membership = db.query(WorkspaceMembership).filter( + WorkspaceMembership.workspace_id == workspace.id, + WorkspaceMembership.user_id == current_user.id, + ).first() + + if not membership: + membership = WorkspaceMembership( + workspace_id=workspace.id, + user_id=current_user.id, + role="member", + ) + db.add(membership) + + invitation.accepted_at = datetime.now(timezone.utc) + db.commit() + + return WorkspaceInviteAcceptResponse( + message="Invitation accepted successfully.", + workspace_name=workspace.name, + ) diff --git a/backend/app/schemas.py b/backend/app/schemas.py index c630538f..44a2e2cf 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -106,6 +106,21 @@ class WorkspaceInviteResponse(BaseModel): expires_in_hours: int +class WorkspaceInviteVerifyResponse(BaseModel): + workspace_name: str + inviter_email: str + inviter_username: str + email: str + expires_at: datetime + is_expired: bool + is_accepted: bool + + +class WorkspaceInviteAcceptResponse(BaseModel): + message: str + workspace_name: str + + class TokenResponse(BaseModel): access_token: str refresh_token: str @@ -180,6 +195,7 @@ class DocumentResponse(BaseModel): uploaded_at: datetime summary: Optional[str] = None # New field for document summary task_id: Optional[str] = None + workspace_id: Optional[str] = None extracted_urls: Optional[List[str]] = None class Config: @@ -251,6 +267,7 @@ class ChatRequest(BaseModel): document_ids: Optional[List[str]] = None session_id: Optional[str] = None top_k: int = Field(default=5, ge=1, le=20) + workspace: Optional[str] = None class SourceChunk(BaseModel): @@ -295,6 +312,7 @@ class ChunkSettings(BaseModel): class UploadUrl(BaseModel): url: str + workspace: Optional[str] = None class ShareAnswerResponse(BaseModel): id: str diff --git a/backend/tests/test_rag_tools.py b/backend/tests/test_rag_tools.py index 30bbc9fa..12eddb81 100644 --- a/backend/tests/test_rag_tools.py +++ b/backend/tests/test_rag_tools.py @@ -154,7 +154,7 @@ def test_pdf_search_tool_formats_chunks_and_graph_context(monkeypatch): retrieve_calls = [] graph_calls = [] - def fake_retrieve(query, user_id, document_id=None, top_k=None): + def fake_retrieve(query, user_id, document_id=None, workspace=None, top_k=None, **kwargs): retrieve_calls.append((query, user_id, document_id)) return chunks diff --git a/backend/tests/test_retriever.py b/backend/tests/test_retriever.py index 6045dde4..66e29083 100644 --- a/backend/tests/test_retriever.py +++ b/backend/tests/test_retriever.py @@ -30,7 +30,25 @@ def test_retrieve_fans_out_transformed_queries_and_merges_duplicates(monkeypatch monkeypatch.setattr(retriever, "embed_query", lambda query: f"embedding:{query}") monkeypatch.setattr(retriever, "get_reranker", lambda: None) - def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): + # Mock SessionLocal and Document + class MockDoc: + id = "policy.pdf" + + class MockQuery: + def filter(self, *args, **kwargs): + return self + def all(self): + return [MockDoc()] + + class MockSession: + def query(self, *args, **kwargs): + return MockQuery() + def close(self): + pass + + monkeypatch.setattr("app.database.SessionLocal", lambda: MockSession()) + + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): searched_queries.append(query_embedding) if query_embedding == "embedding:taxes": return [ @@ -75,3 +93,47 @@ def fake_query_chunks(query_embedding, user_id, document_id=None, top_k=10): assert [chunk["id"] for chunk in chunks] == ["shared", "taxes", "healthcare"] assert chunks[0]["score"] == 1.0 assert chunks[0]["confidence"] == 100.0 + + +def test_retrieve_excludes_soft_deleted_documents(db_session, user, monkeypatch): + from app.models import Document + from app.rag import retriever + + # Create one active document and one deleted document + active_doc = Document( + id="active-doc-id", + user_id=user.id, + filename="active.pdf", + original_name="active.pdf", + is_deleted=False, + ) + deleted_doc = Document( + id="deleted-doc-id", + user_id=user.id, + filename="deleted.pdf", + original_name="deleted.pdf", + is_deleted=True, + ) + db_session.add(active_doc) + db_session.add(deleted_doc) + db_session.commit() + + monkeypatch.setattr("app.database.SessionLocal", lambda: db_session) + monkeypatch.setattr(retriever, "transform_query", lambda _query: ["query"]) + monkeypatch.setattr(retriever, "embed_query", lambda query: "embedding") + monkeypatch.setattr(retriever, "get_reranker", lambda: None) + + captured_doc_ids = [] + def fake_query_chunks(query_embedding, user_id, document_id=None, document_ids=None, top_k=10): + nonlocal captured_doc_ids + captured_doc_ids = document_ids + return [] + + monkeypatch.setattr(retriever, "query_chunks", fake_query_chunks) + monkeypatch.setattr(retriever.CustomBM25Retriever, "_get_relevant_documents", lambda *args, **kwargs: []) + + retriever.retrieve("test query", user_id=user.id) + + # Should only query for the active document ID + assert captured_doc_ids == ["active-doc-id"] + diff --git a/backend/tests/test_workspaces.py b/backend/tests/test_workspaces.py index effdbe3b..01ae03f0 100644 --- a/backend/tests/test_workspaces.py +++ b/backend/tests/test_workspaces.py @@ -1,8 +1,20 @@ -from app.auth import create_access_token, hash_password -from app.models import User, WorkspaceInvitation +import hashlib +from datetime import datetime, timedelta, timezone +from app.auth import create_access_token, hash_password, create_invite_token +from app.models import User, Workspace, WorkspaceMembership, WorkspaceInvitation + + +def test_workspace_invite_creates_workspace_and_membership_for_user(client, db_session, user, monkeypatch): + sent = {} + + def fake_send_email(to, subject, body, html=None): + sent["to"] = to + sent["subject"] = subject + sent["body"] = body + + monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) -def test_workspace_invite_requires_admin(client, db_session, user): token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", @@ -10,47 +22,125 @@ def test_workspace_invite_requires_admin(client, db_session, user): json={"email": "invitee@example.com", "workspace_name": "Engineering"}, ) - assert response.status_code == 403 - assert response.json()["error"]["message"] == "Admin access required" + assert response.status_code == 200 + payload = response.json() + assert payload["email"] == "invitee@example.com" + assert payload["workspace_name"] == "Engineering" + assert "invite_link" in payload + + # Verify workspace and membership were created + workspace = db_session.query(Workspace).filter_by(name="Engineering").first() + assert workspace is not None + + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=user.id + ).first() + assert membership is not None + assert membership.role == "admin" -def test_workspace_invite_creates_invitation_and_sends_email(client, db_session, monkeypatch): - admin = User( - username="admin", - email="admin@example.com", +def test_workspace_invite_existing_member_fails(client, db_session, user, monkeypatch): + # Setup workspace and membership + workspace = Workspace(name="Marketing") + db_session.add(workspace) + db_session.commit() + + member = User( + username="member", + email="member@example.com", hashed_password=hash_password("password123"), - is_admin=True, ) - db_session.add(admin) + db_session.add(member) db_session.commit() - db_session.refresh(admin) - sent = {} - - def fake_send_email(to, subject, body, html=None): - sent["to"] = to - sent["subject"] = subject - sent["body"] = body - - monkeypatch.setattr("app.routes.workspaces.send_email", fake_send_email) + membership1 = WorkspaceMembership(workspace_id=workspace.id, user_id=user.id, role="admin") + membership2 = WorkspaceMembership(workspace_id=workspace.id, user_id=member.id, role="member") + db_session.add(membership1) + db_session.add(membership2) + db_session.commit() - token = create_access_token(admin.id) + token = create_access_token(user.id) response = client.post( "/api/v1/workspaces/invite", headers={"Authorization": f"Bearer {token}"}, - json={"email": "invitee@example.com", "workspace_name": "Engineering"}, + json={"email": "member@example.com", "workspace_name": "Marketing"}, + ) + + assert response.status_code == 400 + assert "already a member" in response.json()["detail"] + + +def test_workspace_invite_verify(client, db_session, user): + invite_token = create_invite_token(user.id, "invitee@example.com", "Sales") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Sales", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + response = client.get( + f"/api/v1/workspaces/invite/verify?token={invite_token}" ) assert response.status_code == 200 payload = response.json() + assert payload["workspace_name"] == "Sales" assert payload["email"] == "invitee@example.com" - assert payload["workspace_name"] == "Engineering" - assert "invite_link" in payload - assert payload["invite_link"].startswith("http") - assert "token=" in payload["invite_link"] - assert sent["to"] == "invitee@example.com" - assert "Invitation to join workspace" in sent["subject"] - - invitation = db_session.query(WorkspaceInvitation).filter_by(email="invitee@example.com").first() - assert invitation is not None - assert invitation.workspace_name == "Engineering" + assert payload["inviter_email"] == user.email + assert payload["is_expired"] is False + assert payload["is_accepted"] is False + + +def test_workspace_invite_accept(client, db_session, user): + invitee = User( + username="invitee", + email="invitee@example.com", + hashed_password=hash_password("password123"), + ) + db_session.add(invitee) + db_session.commit() + + invite_token = create_invite_token(user.id, "invitee@example.com", "Support") + token_hash = hashlib.sha256(invite_token.encode("utf-8")).hexdigest() + expires_at = datetime.now(timezone.utc) + timedelta(hours=24) + + invitation = WorkspaceInvitation( + email="invitee@example.com", + inviter_id=user.id, + token_hash=token_hash, + workspace_name="Support", + expires_at=expires_at, + ) + db_session.add(invitation) + db_session.commit() + + # Pre-create workspace to map + workspace = Workspace(name="Support") + db_session.add(workspace) + db_session.commit() + + token = create_access_token(invitee.id) + response = client.post( + f"/api/v1/workspaces/invite/accept?token={invite_token}", + headers={"Authorization": f"Bearer {token}"}, + ) + + assert response.status_code == 200 + assert response.json()["workspace_name"] == "Support" + + # Verify membership and invitation status + membership = db_session.query(WorkspaceMembership).filter_by( + workspace_id=workspace.id, user_id=invitee.id + ).first() + assert membership is not None + assert membership.role == "member" + + db_session.refresh(invitation) + assert invitation.accepted_at is not None diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index f5a25f0a..3384e242 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -10,6 +10,7 @@ import Header from "@/components/layout/Header"; import DocumentSidebar from "@/components/document/DocumentSidebar"; import ChatSessionSidebar from "@/components/chat/ChatSessionSidebar"; import ChatPanel from "@/components/chat/ChatPanel"; +import { useWorkspaceStore } from "@/store/workspace-store"; function PDFViewerSkeleton() { return (
Verifying invitation details...
+Workspace
+{inviteInfo?.workspace_name}
++ Invited by {inviteInfo?.inviter_username} ({inviteInfo?.inviter_email}) +
++ You are logged in as {user.username} ({user.email}). +
+ +Loading...
+