From ea21639f0b5a67a5d1f4bb5fecf3793a38b70f73 Mon Sep 17 00:00:00 2001 From: Andreas Chandra Date: Mon, 21 Jul 2025 21:13:32 +0700 Subject: [PATCH] #18 add structured output keypoints and summary --- .pylintrc | 2 +- chatbot.py | 94 ++++++++++++++++++++++++++++++++++++++++++------------ schema.py | 14 ++++++++ 3 files changed, 89 insertions(+), 21 deletions(-) create mode 100644 schema.py diff --git a/.pylintrc b/.pylintrc index 1d78cb2..6610280 100644 --- a/.pylintrc +++ b/.pylintrc @@ -345,7 +345,7 @@ indent-after-paren=4 indent-string=' ' # Maximum number of characters on a single line. -max-line-length=100 +max-line-length=150 # Maximum number of lines in a module. max-module-lines=1000 diff --git a/chatbot.py b/chatbot.py index 25aaf9f..bdcb555 100644 --- a/chatbot.py +++ b/chatbot.py @@ -11,7 +11,7 @@ TextLoader, ) from langchain_core.documents import Document -from langchain_core.messages import SystemMessage +from langchain_core.messages import SystemMessage, ToolMessage from langchain_core.tools import tool from langchain_huggingface import HuggingFaceEmbeddings from langchain_text_splitters.character import RecursiveCharacterTextSplitter @@ -21,6 +21,8 @@ from langgraph.graph import END, MessagesState, StateGraph from langgraph.prebuilt import ToolNode, tools_condition +from schema import PointSchema + class KnowledgeBaseChatbot: """Chatbot class for setting up embeddings, vector store, llm, graph, and tools.""" @@ -118,6 +120,53 @@ def retrieve(query: str): return retrieve + def _clean_message_sequence(self, messages: List) -> List: + """Clean up message sequence to ensure tool_use blocks have corresponding tool_result blocks. + + Args: + messages: List of messages to clean + + Returns: + List of cleaned messages + """ + cleaned_messages = [] + i = 0 + + while i < len(messages): + current_msg = messages[i] + + # If current message is AI with tool calls + if ( + current_msg.type == "ai" + and hasattr(current_msg, "tool_calls") + and current_msg.tool_calls + ): + + # Check if next message is a tool result + if i + 1 < len(messages) and messages[i + 1].type == "tool": + # Valid tool_use -> tool_result sequence + cleaned_messages.append(current_msg) + cleaned_messages.append(messages[i + 1]) + i += 2 + else: + # Incomplete tool sequence - skip the tool_use message + print(f"Adding missing tool results for message {i}") + cleaned_messages.append(current_msg) + for tool_call in current_msg.tool_calls: + tool_result = ToolMessage( + content=tool_call["args"], + tool_call_id=tool_call["id"], + name=tool_call["name"], + ) + cleaned_messages.append(tool_result) + i += 1 + else: + # Regular message + cleaned_messages.append(current_msg) + i += 1 + + return cleaned_messages + def query_or_respond(self, state: MessagesState) -> Dict: """entry point for the chatbot to either query the knowledge base or respond to a user message. @@ -132,6 +181,13 @@ def query_or_respond(self, state: MessagesState) -> Dict: print(f"State messages in query_or_respond: {state['messages']}\n") messages = state["messages"] + # Clean up incomplete tool sequences + messages = self._clean_message_sequence(messages) + + for msg in messages: + print("Message type: ", msg.type) + print(f"Msg: {msg}\n") + if not any(msg.type == "system" for msg in messages): system_msg = SystemMessage( content="You are a helpful assistant with access to a document knowledge base. " @@ -143,7 +199,7 @@ def query_or_respond(self, state: MessagesState) -> Dict: messages = [system_msg] + messages llm_with_tools = self.llm.bind_tools([self.retrieve]) - response = llm_with_tools.invoke(state["messages"]) + response = llm_with_tools.invoke(messages) # MessagesState appends messages to state instead of overwriting return {"messages": [response]} @@ -174,6 +230,8 @@ def generate(self, state: MessagesState) -> Dict: "Use the following pieces of retrieved context to answer " "the question. If you don't know the answer, say that you " "don't know." + "Use tool calls to extract key points and summary from the retrieved documents." + "do not use any tool if it is not needed." "\n\n" f"{docs_content}" ) @@ -186,7 +244,9 @@ def generate(self, state: MessagesState) -> Dict: prompt = [SystemMessage(system_message_content)] + conversation_messages # Run - response = self.llm.invoke(prompt) + llm_with_tools = self.llm.bind_tools([PointSchema]) + response = llm_with_tools.invoke(prompt) + return {"messages": [response]} def build_graph(self): @@ -260,25 +320,19 @@ def ask_question(self, question: str, thread_id: str = None) -> Tuple[str, List] doc = artifact retrieved_docs.append(doc) - # retrieved_docs = [] - # for msg in response["messages"]: - # if msg.type == "tool" and hasattr(msg, "artifact") and msg.artifact: - # for artifact in msg.artifact: - # if isinstance(artifact, dict): - # # Convert dict to Document if necessary - # doc = Document( - # id=artifact["id"], - # page_content=artifact["page_content"], - # metadata=artifact["metadata"], - # page_content_type=artifact["page_content"], - # ) - # else: - # doc = artifact - - # retrieved_docs.append(doc) + if latest_response.tool_calls: + for tool_call in latest_response.tool_calls: + print(f"tool_call: {tool_call}") + if tool_call["name"] == "PointSchema": + keypoints = tool_call["args"]["keypoints"] + summary = tool_call["args"]["summary"] + + response = keypoints + "\n\n" + summary + else: + response = latest_response.content print(f"Total retrieved documents: {len(retrieved_docs)}\n") print(f"Retrieved documents: {retrieved_docs}\n") print("=" * 50) - return latest_response.content, retrieved_docs + return response, retrieved_docs diff --git a/schema.py b/schema.py new file mode 100644 index 0000000..cd3d3a6 --- /dev/null +++ b/schema.py @@ -0,0 +1,14 @@ +"""Define all the schemas used in the chatbot.""" + +from pydantic import BaseModel, Field + + +class PointSchema(BaseModel): + """Use this tool if the user asks about keypoints, main points, main arguments, and summary from the document.""" + + keypoints: str = Field( + description="A concise list of the most important facts, findings, or arguments from the source material." + ) + summary: str = Field( + description="A short, readable overview that captures the main message or conclusion." + )