diff --git a/integrations/langchain/src/databricks_langchain/chat_models.py b/integrations/langchain/src/databricks_langchain/chat_models.py index e32098edb..d74ef5c92 100644 --- a/integrations/langchain/src/databricks_langchain/chat_models.py +++ b/integrations/langchain/src/databricks_langchain/chat_models.py @@ -1055,6 +1055,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: return {"role": "user", **message_dict} elif isinstance(message, AIMessage): if tool_calls := _get_tool_calls_from_ai_message(message): + print(tool_calls) message_dict["tool_calls"] = tool_calls # type: ignore[assignment] # If tool calls present, content null value should be None not empty string. message_dict["content"] = message_dict["content"] or None # type: ignore[assignment] @@ -1196,15 +1197,37 @@ def _get_tool_calls_from_ai_message(message: AIMessage) -> List[Dict]: for tc in message.invalid_tool_calls ] + """ + thought signature encodes model reasoning + it is required for each tool call to gemini 3 pro - https://arc.net/l/quote/jhoeoqbl + this means we need to encode this info in the responses events in order to fix this bug, in addition to the work on this PR + + will have to change _langchain_message_stream_to_responses_stream + """ + if tool_calls or invalid_tool_calls: - return tool_calls + invalid_tool_calls + # Merge thoughtSignature from additional_kwargs if present + all_tool_calls = tool_calls + invalid_tool_calls + additional_tool_calls = message.additional_kwargs.get("tool_calls", []) + if additional_tool_calls: + # Create a mapping of tool call IDs to their thoughtSignature + thought_signatures = { + tc.get("id"): tc.get("thoughtSignature") + for tc in additional_tool_calls + if tc.get("thoughtSignature") + } + # Add thoughtSignature to matching tool calls + for tc in all_tool_calls: + if tc["id"] in thought_signatures: + tc["thoughtSignature"] = thought_signatures[tc["id"]] + return all_tool_calls # Get tool calls from additional kwargs if present. return [ { k: v for k, v in tool_call.items() # type: ignore[union-attr] - if k in {"id", "type", "function"} + if k in {"id", "type", "function", "thoughtSignature"} } for tool_call in message.additional_kwargs.get("tool_calls", []) ] diff --git a/integrations/langchain/tests/integration_tests/agent.py b/integrations/langchain/tests/integration_tests/agent.py new file mode 100644 index 000000000..3a4e15a4a --- /dev/null +++ b/integrations/langchain/tests/integration_tests/agent.py @@ -0,0 +1,200 @@ +from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union + +import mlflow +from langchain.messages import AIMessage, AIMessageChunk, AnyMessage +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.tools import BaseTool +from langgraph.graph import END, StateGraph +from langgraph.graph.message import add_messages +from langgraph.prebuilt.tool_node import ToolNode +from mlflow.pyfunc import ResponsesAgent +from mlflow.types.responses import ( + ResponsesAgentRequest, + ResponsesAgentResponse, + ResponsesAgentStreamEvent, + output_to_responses_items_stream, + to_chat_completions_input, +) + +from databricks_langchain import ( + ChatDatabricks, + UCFunctionToolkit, +) + +############################################ +# Define your LLM endpoint and system prompt +############################################ +# TODO: Replace with your model serving endpoint +# LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5" +LLM_ENDPOINT_NAME = "databricks-gemini-3-pro" +llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME) + +# TODO: Update with your system prompt +system_prompt = "You are a helpful assistant that can run Python code." + +############################################################################### +## Define tools for your agent, enabling it to retrieve data or take actions +## beyond text generation +## To create and see usage examples of more tools, see +## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html +############################################################################### +tools = [] + +# You can use UDFs in Unity Catalog as agent tools +# Below, we add the `system.ai.python_exec` UDF, which provides +# a python code interpreter tool to our agent +# You can also add local LangChain python tools. See https://python.langchain.com/docs/concepts/tools + +# TODO: Add additional tools +UC_TOOL_NAMES = ["system.ai.python_exec"] +uc_toolkit = UCFunctionToolkit(function_names=UC_TOOL_NAMES) +tools.extend(uc_toolkit.tools) + +# Use Databricks vector search indexes as tools +# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html#locally-develop-vector-search-retriever-tools-with-ai-bridge +# List to store vector search tool instances for unstructured retrieval. +VECTOR_SEARCH_TOOLS = [] + +# To add vector search retriever tools, +# use VectorSearchRetrieverTool and create_tool_info, +# then append the result to TOOL_INFOS. +# Example: +# VECTOR_SEARCH_TOOLS.append( +# VectorSearchRetrieverTool( +# index_name="", +# # filters="..." +# ) +# ) + +tools.extend(VECTOR_SEARCH_TOOLS) + +##################### +## Define agent logic +##################### + + +class AgentState(TypedDict): + messages: Annotated[Sequence[AnyMessage], add_messages] + custom_inputs: Optional[dict[str, Any]] + custom_outputs: Optional[dict[str, Any]] + + +def create_tool_calling_agent( + model: ChatDatabricks, + tools: Union[ToolNode, Sequence[BaseTool]], + system_prompt: Optional[str] = None, +): + model = model.bind_tools(tools) + + # Define the function that determines which node to go to + def should_continue(state: AgentState): + messages = state["messages"] + last_message = messages[-1] + # If there are function calls, continue. else, end + if isinstance(last_message, AIMessage) and last_message.tool_calls: + return "continue" + else: + return "end" + + if system_prompt: + preprocessor = RunnableLambda( + lambda state: [{"role": "system", "content": system_prompt}] + state["messages"] + ) + else: + preprocessor = RunnableLambda(lambda state: state["messages"]) + model_runnable = preprocessor | model + + def call_model( + state: AgentState, + config: RunnableConfig, + ): + response = model_runnable.invoke(state, config) + + return {"messages": [response]} + + workflow = StateGraph(AgentState) + + workflow.add_node("agent", RunnableLambda(call_model)) + workflow.add_node("tools", ToolNode(tools)) + + workflow.set_entry_point("agent") + workflow.add_conditional_edges( + "agent", + should_continue, + { + "continue": "tools", + "end": END, + }, + ) + workflow.add_edge("tools", "agent") + + return workflow.compile() + + +class LangGraphResponsesAgent(ResponsesAgent): + def __init__(self, agent): + self.agent = agent + + def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse: + session_id = None + if request.custom_inputs and "session_id" in request.custom_inputs: + session_id = request.custom_inputs.get("session_id") + elif request.context and request.context.conversation_id: + session_id = request.context.conversation_id + + if session_id: + mlflow.update_current_trace( + metadata={ + "mlflow.trace.session": session_id, + } + ) + + outputs = [ + event.item + for event in self.predict_stream(request) + if event.type == "response.output_item.done" + ] + return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs) + + def predict_stream( + self, + request: ResponsesAgentRequest, + ) -> Generator[ResponsesAgentStreamEvent, None, None]: + session_id = None + if request.custom_inputs and "session_id" in request.custom_inputs: + session_id = request.custom_inputs.get("session_id") + elif request.context and request.context.conversation_id: + session_id = request.context.conversation_id + + if session_id: + mlflow.update_current_trace( + metadata={ + "mlflow.trace.session": session_id, + } + ) + + cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input]) + + for event in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates", "messages"]): + if event[0] == "updates": + for node_data in event[1].values(): + if len(node_data.get("messages", [])) > 0: + yield from output_to_responses_items_stream(node_data["messages"]) + # filter the streamed messages to just the generated text messages + elif event[0] == "messages": + try: + chunk = event[1][0] + if isinstance(chunk, AIMessageChunk) and (content := chunk.content): + yield ResponsesAgentStreamEvent( + **self.create_text_delta(delta=content, item_id=chunk.id), + ) + except Exception as e: + print(e) + + +# Create the agent object, and specify it as the agent object to use when +# loading the agent back for inference via mlflow.models.set_model() +mlflow.langchain.autolog() +agent = create_tool_calling_agent(llm, tools, system_prompt) +AGENT = LangGraphResponsesAgent(agent) +mlflow.models.set_model(AGENT) diff --git a/integrations/langchain/tests/integration_tests/test_chat_models.py b/integrations/langchain/tests/integration_tests/test_chat_models.py index 9986c6b31..06803d421 100644 --- a/integrations/langchain/tests/integration_tests/test_chat_models.py +++ b/integrations/langchain/tests/integration_tests/test_chat_models.py @@ -243,9 +243,9 @@ class GetWeather(BaseModel): return # Models should make at least one tool call when tool_choice is not "none" - assert len(response.tool_calls) >= 1, ( - f"Expected at least 1 tool call, got {len(response.tool_calls)}" - ) + assert ( + len(response.tool_calls) >= 1 + ), f"Expected at least 1 tool call, got {len(response.tool_calls)}" # The first tool call should be for GetWeather first_call = response.tool_calls[0] @@ -267,9 +267,9 @@ class GetWeather(BaseModel): ] ) # Should call GetWeather tool for the followup question - assert len(response.tool_calls) >= 1, ( - f"Expected at least 1 tool call, got {len(response.tool_calls)}" - ) + assert ( + len(response.tool_calls) >= 1 + ), f"Expected at least 1 tool call, got {len(response.tool_calls)}" tool_call = response.tool_calls[0] assert tool_call["name"] == "GetWeather", f"Expected GetWeather tool, got {tool_call['name']}" assert "location" in tool_call["args"], f"Expected location in args, got {tool_call['args']}" @@ -559,12 +559,8 @@ def test_chat_databricks_chatagent_invoke(): ): python_tool_used = True - assert has_tool_calls, ( - f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}" - ) - assert python_tool_used, ( - f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}" - ) + assert has_tool_calls, f"Expected ChatAgent to use tool calls for fibonacci computation. Content: {response.content}" + assert python_tool_used, f"Expected ChatAgent to use python execution tool for fibonacci computation. Content: {response.content}" @pytest.mark.st_endpoints @@ -847,9 +843,9 @@ def test_chat_databricks_gpt5_stream_with_usage(): ] # Should have exactly ONE usage chunk from the final usage-only chunk - assert len(usage_chunks) == 1, ( - f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}" - ) + assert ( + len(usage_chunks) == 1 + ), f"Expected exactly 1 usage chunk from GPT-5 final chunk, got {len(usage_chunks)}" # Verify usage chunk has correct metadata structure usage_chunk = usage_chunks[0] @@ -860,12 +856,12 @@ def test_chat_databricks_gpt5_stream_with_usage(): assert "total_tokens" in usage_chunk.usage_metadata # Verify token counts are positive - assert usage_chunk.usage_metadata["input_tokens"] > 0, ( - f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}" - ) - assert usage_chunk.usage_metadata["output_tokens"] > 0, ( - f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}" - ) + assert ( + usage_chunk.usage_metadata["input_tokens"] > 0 + ), f"Expected positive input_tokens, got {usage_chunk.usage_metadata['input_tokens']}" + assert ( + usage_chunk.usage_metadata["output_tokens"] > 0 + ), f"Expected positive output_tokens, got {usage_chunk.usage_metadata['output_tokens']}" # Verify total_tokens equals sum of input and output expected_total = ( @@ -875,3 +871,40 @@ def test_chat_databricks_gpt5_stream_with_usage(): f"Expected total_tokens ({usage_chunk.usage_metadata['total_tokens']}) " f"to equal input_tokens + output_tokens ({expected_total})" ) + + +def test_chat_databricks_with_gemini(): + os.environ["DATABRICKS_CONFIG_PROFILE"] = "dogfood" + from .agent import AGENT + + result = AGENT.predict( + { + "input": [ + {"role": "user", "content": "What is 6*7 in Python?"}, + { + "type": "function_call", + "id": "lc_run--e58dec26-ce5d-4597-b4f8-28e6db62cd49", + "call_id": "system__ai__python_exec", + "name": "system__ai__python_exec", + "arguments": '{"code": "print(6 * 7)"}', + }, + { + "type": "function_call_output", + "call_id": "system__ai__python_exec", + "output": '{"format": "SCALAR", "value": "42\\n"}', + }, + { + "type": "message", + "id": "lc_run--dd658def-dfdc-4bc7-b0d9-b6e25d1ecc48", + "content": [ + {"text": "The result of `6 * 7` in Python is 42.", "type": "output_text"} + ], + "role": "assistant", + }, + ] + } + ) + assert result is not None + assert result.output is not None + print(result.model_dump()) + assert False diff --git a/integrations/langchain/tests/unit_tests/test_chat_models.py b/integrations/langchain/tests/unit_tests/test_chat_models.py index db5eac7e4..ea5103655 100644 --- a/integrations/langchain/tests/unit_tests/test_chat_models.py +++ b/integrations/langchain/tests/unit_tests/test_chat_models.py @@ -479,9 +479,9 @@ def test_chat_model_stream_usage_only_chunk_missing_tokens(): usage_chunks = [ chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None ] - assert len(usage_chunks) == 0, ( - f"Expected 0 usage chunks when tokens are missing, got {len(usage_chunks)}" - ) + assert ( + len(usage_chunks) == 0 + ), f"Expected 0 usage chunks when tokens are missing, got {len(usage_chunks)}" def test_chat_model_stream_usage_only_chunk_stream_usage_false(): @@ -532,9 +532,9 @@ def test_chat_model_stream_usage_only_chunk_stream_usage_false(): usage_chunks = [ chunk for chunk in chunks if chunk.content == "" and chunk.usage_metadata is not None ] - assert len(usage_chunks) == 0, ( - f"Expected 0 usage chunks when stream_usage=False, got {len(usage_chunks)}" - ) + assert ( + len(usage_chunks) == 0 + ), f"Expected 0 usage chunks when stream_usage=False, got {len(usage_chunks)}" class GetWeather(BaseModel): @@ -713,6 +713,35 @@ def test_convert_message_with_tool_calls() -> None: assert dict_result == message_with_tools +def test_convert_message_with_tool_calls_and_thought_signature() -> None: + ID = "system__ai__python_exec" + THOUGHT_SIG = "CikBjz1rXxsXPO9F7LWvkXdS3Fkl7lMvmk9yp2iIuuTv0vWI2wRd0vHm5QpZAY89a1" + tool_calls = [ + { + "id": ID, + "type": "function", + "function": { + "name": "system__ai__python_exec", + "arguments": '{"code":"print(6 * 7)"}', + }, + "thoughtSignature": THOUGHT_SIG, + } + ] + message = AIMessage( + content="", + additional_kwargs={"tool_calls": tool_calls}, + ) + + dict_result = _convert_message_to_dict(message) + + assert "tool_calls" in dict_result + assert len(dict_result["tool_calls"]) == 1 + assert dict_result["tool_calls"][0]["id"] == ID + assert dict_result["tool_calls"][0]["type"] == "function" + assert dict_result["tool_calls"][0]["function"]["name"] == "system__ai__python_exec" + assert dict_result["tool_calls"][0]["thoughtSignature"] == THOUGHT_SIG + + def test_convert_tool_message() -> None: tool_message = ToolMessage(content="result", tool_call_id="call_123") result = _convert_message_to_dict(tool_message)