Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ indent-after-paren=4
indent-string=' '

# Maximum number of characters on a single line.
max-line-length=100
max-line-length=150

# Maximum number of lines in a module.
max-module-lines=1000
Expand Down
94 changes: 74 additions & 20 deletions chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
TextLoader,
)
from langchain_core.documents import Document
from langchain_core.messages import SystemMessage
from langchain_core.messages import SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_text_splitters.character import RecursiveCharacterTextSplitter
Expand All @@ -21,6 +21,8 @@
from langgraph.graph import END, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from schema import PointSchema


class KnowledgeBaseChatbot:
"""Chatbot class for setting up embeddings, vector store, llm, graph, and tools."""
Expand Down Expand Up @@ -118,6 +120,53 @@ def retrieve(query: str):

return retrieve

def _clean_message_sequence(self, messages: List) -> List:
"""Clean up message sequence to ensure tool_use blocks have corresponding tool_result blocks.

Args:
messages: List of messages to clean

Returns:
List of cleaned messages
"""
cleaned_messages = []
i = 0

while i < len(messages):
current_msg = messages[i]

# If current message is AI with tool calls
if (
current_msg.type == "ai"
and hasattr(current_msg, "tool_calls")
and current_msg.tool_calls
):

# Check if next message is a tool result
if i + 1 < len(messages) and messages[i + 1].type == "tool":
# Valid tool_use -> tool_result sequence
cleaned_messages.append(current_msg)
cleaned_messages.append(messages[i + 1])
i += 2
else:
# Incomplete tool sequence - skip the tool_use message
print(f"Adding missing tool results for message {i}")
cleaned_messages.append(current_msg)
for tool_call in current_msg.tool_calls:
tool_result = ToolMessage(
content=tool_call["args"],
tool_call_id=tool_call["id"],
name=tool_call["name"],
)
cleaned_messages.append(tool_result)
i += 1
else:
# Regular message
cleaned_messages.append(current_msg)
i += 1

return cleaned_messages

def query_or_respond(self, state: MessagesState) -> Dict:
"""entry point for the chatbot to either
query the knowledge base or respond to a user message.
Expand All @@ -132,6 +181,13 @@ def query_or_respond(self, state: MessagesState) -> Dict:
print(f"State messages in query_or_respond: {state['messages']}\n")

messages = state["messages"]
# Clean up incomplete tool sequences
messages = self._clean_message_sequence(messages)

for msg in messages:
print("Message type: ", msg.type)
print(f"Msg: {msg}\n")

if not any(msg.type == "system" for msg in messages):
system_msg = SystemMessage(
content="You are a helpful assistant with access to a document knowledge base. "
Expand All @@ -143,7 +199,7 @@ def query_or_respond(self, state: MessagesState) -> Dict:
messages = [system_msg] + messages

llm_with_tools = self.llm.bind_tools([self.retrieve])
response = llm_with_tools.invoke(state["messages"])
response = llm_with_tools.invoke(messages)
# MessagesState appends messages to state instead of overwriting
return {"messages": [response]}

Expand Down Expand Up @@ -174,6 +230,8 @@ def generate(self, state: MessagesState) -> Dict:
"Use the following pieces of retrieved context to answer "
"the question. If you don't know the answer, say that you "
"don't know."
"Use tool calls to extract key points and summary from the retrieved documents."
"do not use any tool if it is not needed."
"\n\n"
f"{docs_content}"
)
Expand All @@ -186,7 +244,9 @@ def generate(self, state: MessagesState) -> Dict:
prompt = [SystemMessage(system_message_content)] + conversation_messages

# Run
response = self.llm.invoke(prompt)
llm_with_tools = self.llm.bind_tools([PointSchema])
response = llm_with_tools.invoke(prompt)

return {"messages": [response]}

def build_graph(self):
Expand Down Expand Up @@ -260,25 +320,19 @@ def ask_question(self, question: str, thread_id: str = None) -> Tuple[str, List]
doc = artifact
retrieved_docs.append(doc)

# retrieved_docs = []
# for msg in response["messages"]:
# if msg.type == "tool" and hasattr(msg, "artifact") and msg.artifact:
# for artifact in msg.artifact:
# if isinstance(artifact, dict):
# # Convert dict to Document if necessary
# doc = Document(
# id=artifact["id"],
# page_content=artifact["page_content"],
# metadata=artifact["metadata"],
# page_content_type=artifact["page_content"],
# )
# else:
# doc = artifact

# retrieved_docs.append(doc)
if latest_response.tool_calls:
for tool_call in latest_response.tool_calls:
print(f"tool_call: {tool_call}")
if tool_call["name"] == "PointSchema":
keypoints = tool_call["args"]["keypoints"]
summary = tool_call["args"]["summary"]

response = keypoints + "\n\n" + summary
else:
response = latest_response.content

print(f"Total retrieved documents: {len(retrieved_docs)}\n")
print(f"Retrieved documents: {retrieved_docs}\n")
print("=" * 50)

return latest_response.content, retrieved_docs
return response, retrieved_docs
14 changes: 14 additions & 0 deletions schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""Define all the schemas used in the chatbot."""

from pydantic import BaseModel, Field


class PointSchema(BaseModel):
"""Use this tool if the user asks about keypoints, main points, main arguments, and summary from the document."""

keypoints: str = Field(
description="A concise list of the most important facts, findings, or arguments from the source material."
)
summary: str = Field(
description="A short, readable overview that captures the main message or conclusion."
)