Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _migrate_schema():
("documents", "drive_file_id", "ALTER TABLE documents ADD COLUMN drive_file_id VARCHAR(255)"),
("documents", "drive_folder_id", "ALTER TABLE documents ADD COLUMN drive_folder_id VARCHAR(255)"),
("documents", "drive_synced_at", "ALTER TABLE documents ADD COLUMN drive_synced_at TIMESTAMP"),
("documents", "workspace_id", "ALTER TABLE documents ADD COLUMN workspace_id VARCHAR(36)"),
]
for table, column, ddl in docs_migrations:
if column not in existing_docs_columns:
Expand Down
33 changes: 33 additions & 0 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ class User(Base):
back_populates="user",
cascade="all, delete-orphan",
)
workspace_memberships = relationship(
"WorkspaceMembership",
back_populates="user",
cascade="all, delete-orphan",
)


class ApiKey(Base):
Expand All @@ -185,6 +190,32 @@ class ApiKey(Base):
user = relationship("User", back_populates="api_keys")


class Workspace(Base):
__tablename__ = "workspaces"

id = Column(String(36), primary_key=True, default=generate_uuid)
name = Column(String(100), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))

# Relationships
memberships = relationship("WorkspaceMembership", back_populates="workspace", cascade="all, delete-orphan")
documents = relationship("Document", back_populates="workspace")


class WorkspaceMembership(Base):
__tablename__ = "workspace_memberships"

id = Column(String(36), primary_key=True, default=generate_uuid)
workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=False, index=True)
user_id = Column(GUID, ForeignKey("users.id"), nullable=False, index=True)
role = Column(String(20), default="member", nullable=False) # "admin" | "member"
joined_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))

# Relationships
workspace = relationship("Workspace", back_populates="memberships")
user = relationship("User", back_populates="workspace_memberships")


class WorkspaceInvitation(Base):
__tablename__ = "workspace_invitations"

Expand Down Expand Up @@ -254,6 +285,7 @@ class Document(Base):
drive_synced_at = Column(DateTime, nullable=True)
is_deleted = Column(Boolean, default=False, nullable=False, index=True)
deleted_at = Column(DateTime, nullable=True)
workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=True, index=True)
processing_progress = Column(Integer, default=0)
processing_stage = Column(String(20), default="queued")
retry_count = Column(Integer, default=0)
Expand All @@ -264,6 +296,7 @@ class Document(Base):

# Relationships
owner = relationship("User", back_populates="documents")
workspace = relationship("Workspace", back_populates="documents")
messages = relationship(
"ChatMessage",
back_populates="document",
Expand Down
9 changes: 6 additions & 3 deletions backend/app/rag/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ def get_agent_executor(
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
):
"""Initialize the LangChain ReAct agent executor."""

# Initialize tools
pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k)
pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, workspace=workspace, top_k=top_k)
tools = [pdf_tool, MathTool(), WebSearchTool()]

# Initialize LLM
Expand Down Expand Up @@ -122,6 +123,7 @@ def generate_answer(
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
) -> Dict[str, Any]:
"""
Agentic generation: retrieve via tools β†’ reason β†’ generate answer.
Expand All @@ -146,7 +148,7 @@ def generate_answer(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace)
result = executor.invoke({"input": question, "chat_history": formatted_history})

raw_answer = result.get("output", "")
Expand Down Expand Up @@ -194,6 +196,7 @@ def generate_answer_stream(
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Streaming Agentic pipeline.
Expand All @@ -219,7 +222,7 @@ def generate_answer_stream(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history, workspace)

sources_sent = False

Expand Down
6 changes: 6 additions & 0 deletions backend/app/rag/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def query_bm25(
query: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
top_k: int = 10,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -127,6 +128,11 @@ def query_bm25(
all_results = []

for path in glob.glob(os.path.join(user_dir, "*.pkl")):
# Filter by document_ids if provided
if document_ids is not None:
doc_id = os.path.basename(path).rsplit(".", 1)[0]
if doc_id not in document_ids:
continue
results = _query_single_index(path, tokenized_query, top_k)
all_results.extend(results)

Expand Down
69 changes: 69 additions & 0 deletions backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def invoke(self, query):
class CustomVectorRetriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs")
top_k: int = Field(default=10, description="Top K results")

def _get_relevant_documents(
Expand All @@ -51,6 +52,7 @@ def _get_relevant_documents(
query_embedding=query_vector,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]
Expand All @@ -59,6 +61,7 @@ def _get_relevant_documents(
class CustomBM25Retriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs")
top_k: int = Field(default=10, description="Top K results")

def _get_relevant_documents(
Expand All @@ -69,11 +72,13 @@ def _get_relevant_documents(
query=query,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]



def transform_query(query: str) -> List[str]:
"""Rewrite a user question into multiple retrieval-friendly search queries."""
original_query = query.strip()
Expand Down Expand Up @@ -219,6 +224,7 @@ def retrieve(
query: str,
user_id: str,
document_id: Optional[str] = None,
workspace: Optional[str] = None,
top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -228,17 +234,80 @@ def retrieve(

Returns chunks with confidence scores.
"""
from app.database import SessionLocal
from app.models import Document, WorkspaceMembership

# Fetch active document IDs
db = SessionLocal()
try:
if document_id:
# Check if specific document is active (not deleted)
doc = db.query(Document).filter(
Document.id == document_id,
Document.is_deleted.is_(False),
).first()
if not doc:
return []

# Verify user has access to the document
has_access = False
if str(doc.user_id) == str(user_id):
has_access = True
elif doc.workspace_id:
membership = db.query(WorkspaceMembership).filter(
WorkspaceMembership.workspace_id == doc.workspace_id,
WorkspaceMembership.user_id == user_id,
).first()
if membership:
has_access = True

if not has_access:
return []
active_doc_ids = [str(doc.id)]
else:
# Filter documents by workspace
filters = [Document.is_deleted.is_(False)]
if workspace == "company":
memberships = db.query(WorkspaceMembership).filter(
WorkspaceMembership.user_id == user_id
).all()
workspace_ids = [m.workspace_id for m in memberships]
if not workspace_ids:
return []
filters.append(Document.workspace_id.in_(workspace_ids))
elif workspace == "personal" or not workspace:
filters.append(Document.user_id == user_id)
filters.append(Document.workspace_id.is_(None))
else:
# Specific workspace ID
membership = db.query(WorkspaceMembership).filter(
WorkspaceMembership.workspace_id == workspace,
WorkspaceMembership.user_id == user_id,
).first()
if not membership:
return []
filters.append(Document.workspace_id == workspace)

docs = db.query(Document).filter(*filters).all()
if not docs:
return []
active_doc_ids = [str(doc.id) for doc in docs]
finally:
db.close()

# ── Stage 1: Hybrid Search with Query Transformation ─────────────
effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL
vector_retriever = CustomVectorRetriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids if not document_id else None,
top_k=effective_top_k,
)

bm25_retriever = CustomBM25Retriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids if not document_id else None,
top_k=effective_top_k,
)

Expand Down
2 changes: 2 additions & 0 deletions backend/app/rag/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ class PDFSearchTool(BaseTool):

user_id: str
document_id: Optional[str] = None
workspace: Optional[str] = None
top_k: Optional[int] = None
# We'll store sources here to retrieve them after agent execution
last_sources: List[Dict[str, Any]] = []
Expand All @@ -167,6 +168,7 @@ def _run(self, query: str) -> str:
query=query,
user_id=self.user_id,
document_id=self.document_id,
workspace=self.workspace,
top_k=self.top_k,
)

Expand Down
6 changes: 6 additions & 0 deletions backend/app/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ def generate_answer(
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[list] = None,
workspace: Optional[str] = None,
):
from app.rag.agent import generate_answer as _generate_answer

Expand All @@ -454,6 +455,7 @@ def generate_answer(
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
workspace=workspace,
)


Expand All @@ -464,6 +466,7 @@ def generate_answer_stream(
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[list] = None,
workspace: Optional[str] = None,
):
from app.rag.agent import generate_answer_stream as _generate_answer_stream

Expand All @@ -474,6 +477,7 @@ def generate_answer_stream(
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
workspace=workspace,
)


Expand Down Expand Up @@ -575,6 +579,7 @@ def ask_question(
hf_token=user.hf_token,
top_k=payload.top_k,
chat_history=chat_history,
workspace=payload.workspace,
)

# Store result in cache for future identical questions
Expand Down Expand Up @@ -716,6 +721,7 @@ def event_stream():
hf_token=user.hf_token,
top_k=payload.top_k,
chat_history=chat_history,
workspace=payload.workspace,
):
yield chunk

Expand Down
Loading
Loading