Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/app/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def _migrate_schema():
("documents", "drive_file_id", "ALTER TABLE documents ADD COLUMN drive_file_id VARCHAR(255)"),
("documents", "drive_folder_id", "ALTER TABLE documents ADD COLUMN drive_folder_id VARCHAR(255)"),
("documents", "drive_synced_at", "ALTER TABLE documents ADD COLUMN drive_synced_at TIMESTAMP"),
("documents", "workspace_id", "ALTER TABLE documents ADD COLUMN workspace_id VARCHAR(36)"),
]
for table, column, ddl in docs_migrations:
if column not in existing_docs_columns:
Expand Down
7 changes: 7 additions & 0 deletions backend/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ async def document_cleanup_job():
except Exception as e:
logger.warning(f"Auto-cleanup: Error deleting vectors for document {doc.id}: {e}")

# Delete knowledge graph
try:
from app.rag.graph_builder import delete_graph
delete_graph(user_id=doc.user_id, document_id=doc.id)
except Exception as e:
logger.warning(f"Auto-cleanup: Error deleting graph for document {doc.id}: {e}")

# Delete database record
db.delete(doc)

Expand Down
33 changes: 33 additions & 0 deletions backend/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ class User(Base):
back_populates="user",
cascade="all, delete-orphan",
)
workspace_memberships = relationship(
"WorkspaceMembership",
back_populates="user",
cascade="all, delete-orphan",
)


class ApiKey(Base):
Expand All @@ -185,6 +190,32 @@ class ApiKey(Base):
user = relationship("User", back_populates="api_keys")


class Workspace(Base):
__tablename__ = "workspaces"

id = Column(String(36), primary_key=True, default=generate_uuid)
name = Column(String(100), nullable=False)
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))

# Relationships
memberships = relationship("WorkspaceMembership", back_populates="workspace", cascade="all, delete-orphan")
documents = relationship("Document", back_populates="workspace")


class WorkspaceMembership(Base):
__tablename__ = "workspace_memberships"

id = Column(String(36), primary_key=True, default=generate_uuid)
workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=False, index=True)
user_id = Column(GUID, ForeignKey("users.id"), nullable=False, index=True)
role = Column(String(20), default="member", nullable=False) # "admin" | "member"
joined_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))

# Relationships
workspace = relationship("Workspace", back_populates="memberships")
user = relationship("User", back_populates="workspace_memberships")


class WorkspaceInvitation(Base):
__tablename__ = "workspace_invitations"

Expand Down Expand Up @@ -254,6 +285,7 @@ class Document(Base):
drive_synced_at = Column(DateTime, nullable=True)
is_deleted = Column(Boolean, default=False, nullable=False, index=True)
deleted_at = Column(DateTime, nullable=True)
workspace_id = Column(String(36), ForeignKey("workspaces.id"), nullable=True, index=True)
processing_progress = Column(Integer, default=0)
processing_stage = Column(String(20), default="queued")
retry_count = Column(Integer, default=0)
Expand All @@ -264,6 +296,7 @@ class Document(Base):

# Relationships
owner = relationship("User", back_populates="documents")
workspace = relationship("Workspace", back_populates="documents")
messages = relationship(
"ChatMessage",
back_populates="document",
Expand Down
34 changes: 31 additions & 3 deletions backend/app/rag/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,22 @@ def _format_chat_history(messages: List[Dict[str, str]]) -> str:
def get_agent_executor(
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
):
"""Initialize the LangChain ReAct agent executor."""

# Initialize tools
pdf_tool = PDFSearchTool(user_id=user_id, document_id=document_id, top_k=top_k)
pdf_tool = PDFSearchTool(
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
workspace=workspace,
top_k=top_k,
)
tools = [pdf_tool, MathTool(), WebSearchTool()]

# Initialize LLM
Expand Down Expand Up @@ -119,9 +127,11 @@ def generate_answer(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
) -> Dict[str, Any]:
"""
Agentic generation: retrieve via tools β†’ reason β†’ generate answer.
Expand All @@ -146,7 +156,15 @@ def generate_answer(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
workspace=workspace,
)
result = executor.invoke({"input": question, "chat_history": formatted_history})

raw_answer = result.get("output", "")
Expand Down Expand Up @@ -191,9 +209,11 @@ def generate_answer_stream(
question: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
hf_token: Optional[str] = None,
top_k: Optional[int] = None,
chat_history: Optional[List[Dict[str, str]]] = None,
workspace: Optional[str] = None,
) -> Generator[str, None, None]:
"""
Streaming Agentic pipeline.
Expand All @@ -219,7 +239,15 @@ def generate_answer_stream(

# ── Run Agent ────────────────────────────────────
try:
executor, pdf_tool, formatted_history = get_agent_executor(user_id, document_id, hf_token, top_k, chat_history)
executor, pdf_tool, formatted_history = get_agent_executor(
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
hf_token=hf_token,
top_k=top_k,
chat_history=chat_history,
workspace=workspace,
)

sources_sent = False

Expand Down
6 changes: 6 additions & 0 deletions backend/app/rag/bm25.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def query_bm25(
query: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
top_k: int = 10,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -127,6 +128,11 @@ def query_bm25(
all_results = []

for path in glob.glob(os.path.join(user_dir, "*.pkl")):
# Filter by document_ids if provided
if document_ids is not None:
doc_id = os.path.basename(path).rsplit(".", 1)[0]
if doc_id not in document_ids:
continue
results = _query_single_index(path, tokenized_query, top_k)
all_results.extend(results)

Expand Down
20 changes: 18 additions & 2 deletions backend/app/rag/graph_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,21 @@
settings = get_settings()


def _candidate_graphs(user_id: str, document_id: Optional[str]) -> Iterable[nx.Graph]:
def _candidate_graphs(
user_id: str,
document_id: Optional[str],
document_ids: Optional[List[str]] = None,
) -> Iterable[nx.Graph]:
if document_id:
graph = load_graph(user_id, document_id)
return [graph] if graph is not None else []
elif document_ids:
graphs = []
for doc_id in document_ids:
graph = load_graph(user_id, doc_id)
if graph is not None:
graphs.append(graph)
return graphs

graphs = []
for path in iter_graph_paths(user_id):
Expand Down Expand Up @@ -67,12 +78,17 @@ def get_entity_context(
query: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
) -> str:
"""Return compact graph relationship context relevant to the query."""
relationships: Dict[Tuple[str, str], Dict[str, object]] = {}

try:
graphs = _candidate_graphs(user_id=user_id, document_id=document_id)
graphs = _candidate_graphs(
user_id=user_id,
document_id=document_id,
document_ids=document_ids,
)
for graph in graphs:
matched_nodes = _match_query_nodes(graph, query)

Expand Down
100 changes: 99 additions & 1 deletion backend/app/rag/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def invoke(self, query):
class CustomVectorRetriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs")
top_k: int = Field(default=10, description="Top K results")

def _get_relevant_documents(
Expand All @@ -51,6 +52,7 @@ def _get_relevant_documents(
query_embedding=query_vector,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]
Expand All @@ -59,6 +61,7 @@ def _get_relevant_documents(
class CustomBM25Retriever(BaseRetriever):
user_id: str = Field(description="User ID")
document_id: Optional[str] = Field(default=None, description="Document ID")
document_ids: Optional[List[str]] = Field(default=None, description="List of Document IDs")
top_k: int = Field(default=10, description="Top K results")

def _get_relevant_documents(
Expand All @@ -69,11 +72,13 @@ def _get_relevant_documents(
query=query,
user_id=self.user_id,
document_id=self.document_id,
document_ids=self.document_ids,
top_k=self.top_k,
)
return [LangchainDocument(page_content=c["text"], metadata=c) for c in candidates]



def transform_query(query: str) -> List[str]:
"""Rewrite a user question into multiple retrieval-friendly search queries."""
original_query = query.strip()
Expand Down Expand Up @@ -206,9 +211,11 @@ def _merge_candidates(candidates: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

@trace_function(
"retrieve",
metadata_factory=lambda query, user_id, document_id=None, top_k=None: {
metadata_factory=lambda query, user_id, document_id=None, top_k=None, **kwargs: {
"user_id": user_id,
"document_id": document_id,
"document_ids": kwargs.get("document_ids"),
"workspace": kwargs.get("workspace"),
"embedding_model": settings.EMBEDDING_MODEL,
"reranker_model": settings.RERANKER_MODEL,
"top_k_retrieval": settings.TOP_K_RETRIEVAL,
Expand All @@ -219,6 +226,8 @@ def retrieve(
query: str,
user_id: str,
document_id: Optional[str] = None,
document_ids: Optional[List[str]] = None,
workspace: Optional[str] = None,
top_k: Optional[int] = None,
) -> List[Dict[str, Any]]:
"""
Expand All @@ -228,17 +237,106 @@ def retrieve(

Returns chunks with confidence scores.
"""
from app.database import SessionLocal
from app.models import Document, WorkspaceMembership

# Fetch active document IDs
db = SessionLocal()
try:
if document_id:
# Check if specific document is active (not deleted)
doc = db.query(Document).filter(
Document.id == document_id,
Document.is_deleted.is_(False),
).first()
if not doc:
return []

# Verify user has access to the document
has_access = False
if str(doc.user_id) == str(user_id):
has_access = True
elif doc.workspace_id:
membership = db.query(WorkspaceMembership).filter(
WorkspaceMembership.workspace_id == doc.workspace_id,
WorkspaceMembership.user_id == user_id,
).first()
if membership:
has_access = True

if not has_access:
return []
active_doc_ids = [str(doc.id)]
elif document_ids:
# Filter by a list of document IDs and verify access for each
filters = [
Document.id.in_(document_ids),
Document.is_deleted.is_(False),
]
docs = db.query(Document).filter(*filters).all()

accessible_docs = []
for doc in docs:
has_access = False
if str(doc.user_id) == str(user_id):
has_access = True
elif doc.workspace_id:
membership = db.query(WorkspaceMembership).filter(
WorkspaceMembership.workspace_id == doc.workspace_id,
WorkspaceMembership.user_id == user_id,
).first()
if membership:
has_access = True
if has_access:
accessible_docs.append(doc)

if not accessible_docs:
return []
active_doc_ids = [str(doc.id) for doc in accessible_docs]
else:
# Filter documents by workspace
filters = [Document.is_deleted.is_(False)]
if workspace == "company":
memberships = db.query(WorkspaceMembership).filter(
WorkspaceMembership.user_id == user_id
).all()
workspace_ids = [m.workspace_id for m in memberships]
if not workspace_ids:
return []
filters.append(Document.workspace_id.in_(workspace_ids))
elif workspace == "personal" or not workspace:
filters.append(Document.user_id == user_id)
filters.append(Document.workspace_id.is_(None))
else:
# Specific workspace ID
membership = db.query(WorkspaceMembership).filter(
WorkspaceMembership.workspace_id == workspace,
WorkspaceMembership.user_id == user_id,
).first()
if not membership:
return []
filters.append(Document.workspace_id == workspace)

docs = db.query(Document).filter(*filters).all()
if not docs:
return []
active_doc_ids = [str(doc.id) for doc in docs]
finally:
db.close()

# ── Stage 1: Hybrid Search with Query Transformation ─────────────
effective_top_k = top_k if top_k is not None else settings.TOP_K_RETRIEVAL
vector_retriever = CustomVectorRetriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids if not document_id else None,
top_k=effective_top_k,
)

bm25_retriever = CustomBM25Retriever(
user_id=user_id,
document_id=document_id,
document_ids=active_doc_ids if not document_id else None,
top_k=effective_top_k,
)

Expand Down
Loading
Loading