diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 77a591e..555263c 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -60,7 +60,7 @@ jobs: npm install npx semantic-release env: - PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + # PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }} RELEASE_TEST_PYPI: ${{ github.event.repository.is_template || contains(github.repository, 'template') }} # dry run if not on main/master branch, or if initial commit diff --git a/.gitignore b/.gitignore index 3cc91b1..1439a90 100644 --- a/.gitignore +++ b/.gitignore @@ -142,3 +142,6 @@ dmypy.json # Pyre type checker .pyre/ + +# lm eval cache +hf_cache/ diff --git a/GoT/__init__.py b/GoT/__init__.py index e6c9568..c9d8fc7 100644 --- a/GoT/__init__.py +++ b/GoT/__init__.py @@ -1,39 +1,75 @@ import json import logging +from dotenv import load_dotenv from lm_eval import evaluator, tasks -from GoT.model.graph_model import invoke_graph, set_prompt -from GoT.model.lm_wrapper import LangGraphLMWrapper +from GoT.model.graph_model import call_graph +from GoT.model.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper +from GoT.model.utils.parse_args import call_benchmark, defining_and_parse_args +from GoT.model.utils.utils import ( + print_benchmark_result, + print_benchmark_result_loglikehood, +) logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger("GoT") +load_dotenv() -def lm_eval_benchmark(): - task_list = ["gsm8k"] - lm = LangGraphLMWrapper() +# Possible filter = "flexible", "none", "strict" + + +def lm_eval_test_benchmark(): + task_name = "gpqa_diamond_zeroshot" + task_list = [task_name] + test_lm = TestBigBenchWrapper() + task_dict = tasks.get_task_dict(task_list) + + results = evaluator.evaluate( + lm=test_lm, + task_dict=task_dict, + limit=2, # Limit the number of samples + log_samples=True, + # samples={task_name: [20, 25, 100]}, + ) + + # Save results to a JSON file + with 
open("test_benchmark_results.json", "w") as f: + json.dump(results["samples"], f, indent=2) + + print_benchmark_result(results, task_name, filter="strict-match") + + +def lm_eval_graph_benchmark(): + # hendrycks_math_geometry + task_name = "gpqa_diamond_zeroshot" + task_list = [task_name] + lm = LangGraphBigBenchWrapper() task_dict = tasks.get_task_dict(task_list) results = evaluator.evaluate( lm=lm, + # limit=1, task_dict=task_dict, - limit=5, # Limit to 2 samples for quick testing + samples={task_name: [20, 25]}, log_samples=True, ) # Save results to a JSON file with open("graph_benchmark_results.json", "w") as f: - json.dump(results, f, indent=2) + json.dump(results, f, indent=2, default=str) + + print_benchmark_result_loglikehood(results, task_name, filter_val="none") def custom_test(): - set_prompt("What is 4726621 + 2 * 392 - 3432?") - invoke_graph() + call_graph("Solve this integral ∫x2⋅ex2dx") def main(): # It could be changed with custom_test() to test a custom problem instead of the benchmark - lm_eval_benchmark() + args = defining_and_parse_args() + call_benchmark(args) # let this be the last line of this file diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index 36cdf9b..57491ad 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -2,14 +2,15 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from langgraph.graph import StateGraph, MessagesState, START, END -from GoT.model.ollama_llm import OllamaLLM +from GoT.model.ollama_llm import LLM from GoT.model.runtime_graph import ( BacktrackNode, CompletitionNode, + CraftingNode, GoalNode, + ReasoningNode, Response, RuntimeGraph, - RuntimeNode, Score, TestNode, ToolNode, @@ -19,43 +20,112 @@ parse_response, parse_response_for_tool_node, parse_score, - parse_tool_list, - remove_tools_from_list, ) +from GoT.tools.runtime_graph_tool import divide_thought +SCORE_THRESHOLD = 5 +COMPLEXITY_COEFFICIENT = 0.5 +COMPLEXITY_THRESHOLD = 5.5 +MAX_INTERACTIONS = 
10 # Defining agents -starting_agent = OllamaLLM().create_custom_agent( - OllamaLLM().get_tools(), +starting_agent = LLM().create_custom_agent( + LLM().get_tools(), SystemMessage( - "You are an assistant specialized in tools. Your goal is not to resolve the problem," - " only to make list with the best tool to use. " + "You are an assistant specialized in tools. " + " Your goal is not to resolve the problem, but only to make list with the best tool to use. " + " If you are not sure of the tool you have, think also a generic tool to craft that could be useful to solve the problem (Specify the craft is needed). " + " If the problem require too much steps with the current tools, consider crafting a new tool." "The list MUST be in this format and it is not possible to format the tool_name in any way: " - "- tool_name " - "- tool_name " - "- tool_name " + "- tool_name: explanation of why this tool is useful in this problem " + "- tool_name: explanation of why this tool is useful in this problem " + "- tool_name: explanation of why this tool is useful in this problem " ), + type="remote_standard", ) -chat_completition_agent = OllamaLLM().create_custom_agent([]) +chat_completition_agent = LLM().create_custom_agent([]) -judge_agent = OllamaLLM().create_custom_agent( +judge_agent = LLM().create_custom_agent( [], SystemMessage( - "You are an assistant specialized in validation of response, like an LLM-as-a-judge. " - "Your duty is to score, from 0 to 6, the response that user gives to you and assign to it a score. " - "You MUST respond ONLY using the Score function. " - "Do not write natural language outside the function. " - "If you fail to respect the format, the evaluation will fail." - "\n0: The response is impossible to understand and completely wrong. " - "\n1: The response is near to be completely wrong. " - "\n2: The response is in the correct language but it doesn't follow the instruction. 
" - "\n3: The response try to resolve the problem but doesn't follow the instruction or the response is wrong. " - "\n4: The response follow the instruction but the result is wrong or the result is correct but doesn't follow the instruction. " - "\n5: The response follow the instruction and the result is near to the solution (If the task is hard, the solution should be near to the corrected one). " - "\n6: The response follow the instruction and the result is perfectly correct." + """ + You are an assistant specialized in validating responses, like an LLM-as-a-judge. + Your duty is to score, from 0 to 5, the response that user gives and assign a score. + + Rules: + - You MUST respond ONLY using the Score function. + - You cannot give the full solution, only hints. + - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. + - Do not write natural language outside the function. + - Always consider creating a tool if it makes the response correct or reusable. + + Score meanings: + 0: Impossible to understand / completely wrong + 1: Nearly completely wrong + 2: Correct language but does not follow instruction + 3: Tries to solve but fails instruction / wrong + 4: Follows instruction but result wrong or incomplete + 5: Follows instruction, result correct + + Example: + - User response: cannot compute because missing helper + - Hint: implement a tool that computes the missing function + """ ), response_format=Score, + type="remote_score_format", +) + +crafter_agent = LLM().create_custom_agent( + LLM().get_craft_tool(), + SystemMessage( + """ + You create reusable Python tools for other agents. + + The tool must be GENERAL and parameterized. + Never hardcode values from the current problem. 
+ + Bad example (too specific): + def multiply_3_and_7(): + return 3 * 7 + + Good example (general): + def multiply(a: float, b: float) -> float: + ' + Arguments: + a: the first number to multiply + b: the second number to multiply + Returns: + The product of a and b + ' + return a * b + + Rules: + - Prefer generic names and parameters, never craft specific functions. + - If the function contains specific numbers or values, it is wrong. + - Craft only one function, it must contains always the docs. + - Never craft tool that raise exceptions. + - Respond ONLY using the tool available. + - No natural language. + - No comments in the python interpreter. + """ + ), + response_format=Response, + type="remote_response_format", +) + +reasoning_agent = LLM().create_custom_agent( + [divide_thought], + SystemMessage( + "You are an assistant specialized in divide the complexity. " + "Your goal is to resolve the problem with reasoning. You should to reason step by step and write all your reasoning. " + "You are forced to divide it into smaller parts." + "The parts must be indipendent from each other and must be solvable at the same time." + "Respond with the indicated format." + ), + response_format=Response, + type="remote_response_format", ) # Defining runtime graph @@ -67,63 +137,107 @@ def goal(prompt: MessagesState): goal_node = GoalNode(parse_response(prompt), resolved=True) runtime_graph.add_node(goal_node) runtime_graph.temp_node = goal_node - return prompt - - -def tool_expand(goal: MessagesState): - msg = parse_response(goal) - sys_msg = "Which tools that I have can I use to solve this problem? Please make a list using '-' to denote each tool in a probabilistic order, don't use this character for other reasons." 
- messages = [ - HumanMessage(msg), - SystemMessage(sys_msg), - ] - res = starting_agent.invoke({"messages": messages}) - str_res = parse_response(res) - tool_list = parse_tool_list(str_res) # Toglie elementi inutili - # add tool nodes in the runtime graph - for tool in tool_list: - tool_node = RuntimeNode(resolved=True) + for i in range(0, 3): call_node = ToolNode( - f"Please, resolve the problem with the tool: {tool}", + "Please, resolve the problem with the tools given, you MUST follow the previous reasoning.", "", - tool_name=tool, + tool_name="", ) - runtime_graph.add_node(tool_node) - runtime_graph.add_edge( - runtime_graph.temp_node, tool_node - ) # edge from goal to tool node + reasoning_node = ReasoningNode("") + runtime_graph.add_node(reasoning_node) runtime_graph.add_node(call_node) - runtime_graph.add_edge(tool_node, call_node) - runtime_graph.add_tool_link(call_node, tool) - # extract a tool to call + runtime_graph.add_edge(goal_node, reasoning_node) + runtime_graph.add_edge(reasoning_node, call_node) + reasoning_node = ReasoningNode("") + runtime_graph.add_node(reasoning_node) + runtime_graph.add_edge(goal_node, reasoning_node) + # extract a reasoning node to resolve runtime_graph.temp_node = runtime_graph.call_tool_node() - return runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node) + return prompt + + +# def tool_expand(goal: MessagesState): +# msg = parse_response(goal) +# sys_msg = "Please make a list using '-' to denote each tool in a probabilistic order, don't use this character for other reasons. Select only the tool(s) you want to use to solve this problem." 
+# messages = [ +# HumanMessage(msg), +# SystemMessage(sys_msg), +# ] +# res = starting_agent.invoke({"messages": messages}, config={"recursion_limit": MAX_INTERACTIONS}) +# str_res = parse_response(res) +# goal["messages"].append(AIMessage(content=str_res)) +# # tool_list = parse_tool_list(str_res) # Toglie elementi inutili +# # add tool nodes in the runtime graph + +# return goal + + +def tool_reasoning(messages: MessagesState): + messages["messages"].append( + HumanMessage( + "Please, reason step by step about how to use these tools to solve the problem, without solving it. If the main solution require to craft a tool, only explain how to craft the tool" + ) + ) + result = parse_response( + starting_agent.invoke(messages, config={"recursion_limit": MAX_INTERACTIONS}) + ) + messages["messages"].append(AIMessage(result)) + runtime_graph.resolve_node(runtime_graph.temp_node, result) + runtime_graph.temp_node = runtime_graph.nodes.get(runtime_graph.temp_node, [])[0] + if not isinstance(runtime_graph.temp_node, ToolNode): + raise TypeError("Expected ToolNode after reasoning") + messages["messages"].append(SystemMessage(runtime_graph.temp_node.prompt)) + return messages def tool_call(messages: MessagesState): # It calls the llm and it resolves the call node call_node = runtime_graph.temp_node - tool_agent = OllamaLLM().create_custom_agent( - remove_tools_from_list( - OllamaLLM().get_tools(), runtime_graph.get_resolved_tools() - ), + tool_agent = LLM().create_custom_agent( + LLM().get_tools() + [divide_thought], SystemMessage( "You are an assistant specialized in tools. Your goal is to resolve the problem with " - " the tool that the user indicates to you. You MUST use the tool that user indicates to you." - "You MUST respond ONLY using the Response function. " + " the tool that the user indicates to you. You HAVE to use or craft the tool that the assistant indicates to you." "Do not write natural language outside the function. 
" "If you fail to respect the format, the evaluation will fail." ), response_format=Response, + type="remote_response_format", ) - - res = tool_agent.invoke(messages) + try: + res = tool_agent.invoke( + {"messages": messages["messages"], "tool_choice": Response}, + config={"recursion_limit": MAX_INTERACTIONS}, + ) + except Exception: + message = [ + HumanMessage( + content="Original task:\n" + parse_response(runtime_graph.goal) + ), + HumanMessage( + content="Divide the problem into smaller parts, you can't call the same tool twice." + ), + ] + tool_agent = LLM().create_custom_agent( + [divide_thought], + SystemMessage( + "You are an assistant specialized in tools. Your goal is to resolve the problem with " + " the tool that the user indicates to you. You HAVE to use the tool that the assistant indicates to you." + "Do not write natural language outside the function. " + "If you fail to respect the format, the evaluation will fail." + ), + response_format=Response, + type="remote_response_format", + ) + res = tool_agent.invoke( + {"messages": message, "tool_choice": Response}, + config={"recursion_limit": MAX_INTERACTIONS}, + ) tool_used = extract_tool_used(res) runtime_graph.temp_response.response = parse_response_for_tool_node(res).response parsed_res = f"Response: {parse_response_for_tool_node(res).response}\nExplanation: {parse_response_for_tool_node(res).explanation}" runtime_graph.resolve_node(call_node, parsed_res) - # Add test node test_node = TestNode( f"{parsed_res}", "", @@ -147,7 +261,7 @@ def response_evaluation(messages: MessagesState): # Create a proper message for the judge with the solution judge_messages = [ HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)), - HumanMessage(content="Solution:\n" + call_node_response), + HumanMessage(content=call_node_response), SystemMessage( content="Score this solution based on correctness and following instructions." 
), @@ -156,10 +270,54 @@ def response_evaluation(messages: MessagesState): ), ] - score_res = parse_score(judge_agent.invoke({"messages": judge_messages})) + score_res = parse_score( + judge_agent.invoke( + {"messages": judge_messages, "tool_choice": Score}, + config={"recursion_limit": MAX_INTERACTIONS}, + ) + ) test_node.score = score_res.score + test_node.need_tool_crafting = score_res.need_tool_crafting + test_node.problem_complexity = score_res.problem_complexity runtime_graph.resolve_node(test_node, score_res.description) + return messages + +def crafting(messages: MessagesState): + crafting_node = CraftingNode(response="", tool_crafted="", resolved=False) + runtime_graph.add_node(crafting_node) + runtime_graph.add_edge(runtime_graph.temp_node, crafting_node) + runtime_graph.temp_node = crafting_node + crafting_messages = [ + HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)), + SystemMessage( + content="Craft a tool to solve this problem using craft_tool. It must be a function" + ), + ] + craft_res = crafter_agent.invoke( + {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS} + ) + runtime_graph.temp_response.response = parse_response_for_tool_node( + craft_res + ).response + parsed_res = f"Response: {parse_response_for_tool_node(craft_res).response}\nExplanation: {parse_response_for_tool_node(craft_res).explanation}" + runtime_graph.resolve_node(crafting_node, parsed_res) + runtime_graph.temp_node = runtime_graph.call_tool_node() + runtime_graph.add_edge(crafting_node, runtime_graph.temp_node) + global starting_agent + starting_agent = LLM().create_custom_agent( + LLM().get_tools(), + SystemMessage( + "You are an assistant specialized in tools. Your goal is not to resolve the problem," + " only to make list with the best tool to use. 
" + "The list MUST be in this format and it is not possible to format the tool_name in any way: " + "- tool_name " + "- tool_name " + "- tool_name " + ), + type="remote_standard", + ) + messages["messages"].append(AIMessage(parsed_res)) return messages @@ -169,15 +327,37 @@ def test_result(messages: MessagesState): if not isinstance(test_node, TestNode): raise TypeError("Expected TestNode for scoring") - if test_node.score >= 5: + if test_node.score >= ( + COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity + ): runtime_graph.add_edge(test_node, runtime_graph.temp_response) runtime_graph.temp_response.resolved = True return END - elif test_node.score < 5 and n is True: + elif ( + test_node.score + < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) + and n is True + and test_node.need_tool_crafting is True + ): + return "crafting" + elif ( + test_node.score + < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) + and n is True + ): + if test_node.need_tool_crafting is True: + test_node.response = "The problem is too complex to craft a new tool, try reason step by step or divide complexity." 
return "backtrack" + elif ( + test_node.score + < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) + and n is False + and runtime_graph.exist_reasoning_node_available() + ): + return "reasoning_mode" else: chat_completition_node = CompletitionNode( - "Please, solve this problem", + "Please, solve this problem with the corrected format.", "", ) runtime_graph.add_node(chat_completition_node) @@ -186,6 +366,31 @@ def test_result(messages: MessagesState): return "chat_completition" +def reasoning_mode(messages: MessagesState): + reasoning_node = runtime_graph.call_tool_node() + if not isinstance(reasoning_node, ReasoningNode): + raise TypeError("Expected ReasoningNode for reasoning mode") + msg = [HumanMessage(content=parse_response(runtime_graph.goal))] + result = parse_response( + reasoning_agent.invoke( + {"messages": msg}, config={"recursion_limit": MAX_INTERACTIONS} + ) + ) + runtime_graph.resolve_node(reasoning_node, result) + runtime_graph.temp_response.response = result + test_node = TestNode( + f"{result}", + "", + score=0, + tool_used=[], + ) + runtime_graph.add_node(test_node) + runtime_graph.add_edge(reasoning_node, test_node) + runtime_graph.temp_node = test_node + messages["messages"].append(AIMessage(result)) + return messages + + def backtrack(messages: MessagesState): test_node = runtime_graph.temp_node if not isinstance(test_node, TestNode): @@ -197,7 +402,7 @@ def backtrack(messages: MessagesState): runtime_graph.add_edge( backtrack_node, runtime_graph.temp_node ) # tool call node that we want to resolve - messages = runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node) + # messages = runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node) messages.get("messages", []).append(AIMessage(backtrack_node.feedback)) return messages @@ -216,50 +421,64 @@ def chat_completition(messages: MessagesState): return new_messages_history -# https://docs.langchain.com/oss/python/langgraph/overview - 
content = "" -def set_prompt(prompt: str): +def call_graph(prompt: str): global content content = prompt + return invoke_graph() def invoke_graph(): graph = StateGraph(MessagesState) graph.add_node(goal) - graph.add_node(tool_expand) + # graph.add_node(tool_expand) + graph.add_node(tool_reasoning) graph.add_node(tool_call) graph.add_node(backtrack) + graph.add_node(crafting) graph.add_node(chat_completition) graph.add_node(response_evaluation) + graph.add_node(reasoning_mode) graph.add_edge(START, "goal") - graph.add_edge("goal", "tool_expand") - graph.add_edge("tool_expand", "tool_call") + # graph.add_edge("goal", "tool_expand") + graph.add_edge("goal", "tool_reasoning") + graph.add_edge("tool_reasoning", "tool_call") graph.add_edge("tool_call", "response_evaluation") - graph.add_edge("backtrack", "tool_call") + graph.add_edge("reasoning_mode", "response_evaluation") + graph.add_edge("backtrack", "tool_reasoning") + graph.add_edge("crafting", "tool_reasoning") graph.add_edge("chat_completition", END) graph.add_conditional_edges( "response_evaluation", test_result, - {"backtrack": "backtrack", "chat_completition": "chat_completition", END: END}, + { + "crafting": "crafting", + "backtrack": "backtrack", + "reasoning_mode": "reasoning_mode", + "chat_completition": "chat_completition", + END: END, + }, ) graph = graph.compile() logger.info(graph.get_graph().draw_mermaid()) - - res = graph.invoke( - { - "messages": [ - { - "role": "user", - "content": content, - } - ] - } - ) + try: + res = graph.invoke( + { + "messages": [ + { + "role": "user", + "content": content, + } + ] + } + ) + except Exception as e: + runtime_graph.clear() + raise e res["output"] = runtime_graph.temp_response.response diff --git a/GoT/model/lm_wrapper.py b/GoT/model/lm_wrapper.py index 332d0a6..1013c0b 100644 --- a/GoT/model/lm_wrapper.py +++ b/GoT/model/lm_wrapper.py @@ -1,9 +1,10 @@ -import math -from GoT.model.graph_model import invoke_graph, set_prompt +from GoT.model.graph_model import 
call_graph from lm_eval.api.registry import register_model from lm_eval.api.model import LM -from GoT.model.utils.utils import extract_output +from GoT.model.ollama_llm import LLM +from langchain_core.messages import HumanMessage +from GoT.model.utils.utils import extract_output, normalize_number, parse_response class LangGraphLM: @@ -18,12 +19,25 @@ def generate(self, requests, max_new_tokens=None): outputs = [] for r in requests: prompt = r["prompt"] - set_prompt(prompt) # Imposta il prompt globale - result = invoke_graph() # usa la tua funzione + result = call_graph(prompt) outputs.append(result["output"]) return outputs +class OllamaTestLM: + def __init__(self): + pass + + def generate(self, requests, max_new_tokens=None): + agent = LLM().create_custom_agent(tools=LLM().get_tools()) + outputs = [] + for r in requests: + prompt = r["prompt"] + result = agent.invoke({"messages": [HumanMessage(content=prompt)]}) + outputs.append(normalize_number(parse_response(result))) + return outputs + + # Register your custom model with lm_eval @register_model("langgraph_lm") class LangGraphLMWrapper(LM): @@ -58,28 +72,72 @@ def generate_until(self, requests, until=None, max_new_tokens=None, **kwargs): question = str(request) # Call invoke_graph with just the question - set_prompt(question) - result = invoke_graph() + result = call_graph(question) # Extract the final answer from the result - output = "" - if isinstance(result, dict) and "messages" in result: - messages = result["messages"] - if messages: - # The last message should be the response - last_msg = messages[-1] - - # Extract content from message - if hasattr(last_msg, "content"): - output = last_msg.content - elif isinstance(last_msg, dict) and "content" in last_msg: - output = last_msg["content"] - else: - output = str(last_msg) + output = ( + result["output"] + if isinstance(result, dict) and "output" in result + else "" + ) + normalize_output = normalize_number(output) + + outputs.append(normalize_output if 
normalize_output else "") + + except Exception as e: + print(f"Error in request {i}: {e}") + import traceback + + traceback.print_exc() + outputs.append("") + + return outputs + + def loglikelihood(self, requests): + return super().loglikelihood(requests) + + def loglikelihood_rolling(self, requests): + return super().loglikelihood_rolling(requests) + + +@register_model("ollama_lm_test") +class OllamaTestLMWrapper(LM): + def __init__(self, model_args=""): + super().__init__() + self.lm = OllamaTestLM() + + def generate_until(self, requests, until=None, max_new_tokens=None, **kwargs): + """ + Generate text until a stopping condition is met. + """ + agent = LLM().create_custom_agent(tools=LLM().get_tools()) + outputs = [] + for i, request in enumerate(requests): + try: + # Extract the question directly from request.doc + # This is much simpler than parsing the full prompt + if hasattr(request, "doc") and isinstance(request.doc, dict): + question = request.doc.get("question", "") else: - output = str(result) if result else "" + # Fallback: try to extract from arguments + if hasattr(request, "arguments") and request.arguments: + full_prompt = request.arguments[0][0] + # Just take everything after the last "Answer:" as the question + last_answer_idx = full_prompt.rfind("Answer:") + if last_answer_idx != -1: + question = full_prompt[ + last_answer_idx + 8 : + ].strip() # 8 = len("Answer:") + else: + question = full_prompt + else: + question = str(request) - outputs.append(output if output else "") + # Call invoke_graph with just the question + result = agent.invoke({"messages": [HumanMessage(content=question)]}) + normalize_output = normalize_number(parse_response(result)) + + outputs.append(normalize_output if normalize_output else "") except Exception as e: print(f"Error in request {i}: {e}") @@ -122,7 +180,14 @@ def _extract_text_from_request(self, request): return prompt # Fallback per altri campi - for field in ["input", "question", "text", "prompt", "instruction"]: 
+ for field in [ + "input", + "question", + "problem", + "text", + "prompt", + "instruction", + ]: if field in doc and doc[field]: return str(doc[field]) @@ -156,8 +221,7 @@ def generate_until(self, requests, until=None, max_new_tokens=None, **kwargs): try: question = self._extract_text_from_request(request) - set_prompt(question) - result = invoke_graph() + result = call_graph(question) # Estrai l'output output = extract_output(result) @@ -180,41 +244,33 @@ def generate_until(self, requests, until=None, max_new_tokens=None, **kwargs): return outputs def loglikelihood(self, requests): - """ - Calcola la log-likelihood per BigBench multiple choice tasks. - - Riceve una lista di tuple: (context, continuation) - Restituisce lista di tuple: (log_likelihood_score, is_greedy) - - NOTA: Poiché invoke_graph non fornisce logits, usiamo un'euristica. - Se hai bisogno di precision massima, modifica invoke_graph per restituire logits. - """ outputs = [] + cache = {} # context -> score già calcolato for i, request in enumerate(requests): try: - # Estrai context e continuation - if isinstance(request, tuple) and len(request) == 2: + if hasattr(request, "arguments") and request.arguments: + context, continuation = request.arguments + elif isinstance(request, tuple) and len(request) == 2: context, continuation = request else: - # Fallback - context = "" - continuation = str(request) - - # Prepara il prompt per il modello - full_text = context + continuation + context, continuation = request - # Chiama il grafo - set_prompt(full_text) - result = invoke_graph() - generated_output = extract_output(result) + # Se già calcolato, riutilizza lo score + if context in cache: + score = cache[context] + else: + result = call_graph(context) + generated_output = normalize_number(extract_output(result)) + print(f"DEBUG - Context: {context}") + cache[context] = generated_output - # Calcola il score di likelihood + gen_text = cache[context] score = self._calculate_likelihood_score( - 
generated_output, continuation, context + gen_text, continuation, context ) - outputs.append((score, False)) + outputs.append((score, score >= 1.0)) except Exception as e: print(f"Error in BigBench loglikelihood request {i}: {e}") @@ -239,36 +295,134 @@ def _calculate_likelihood_score( Calcola un score di likelihood per BigBench. Strategie: - 1. Se la generazione contiene esattamente la continuation target -> score alto (0.0) + 1. Se la generazione contiene esattamente la continuation target -> score alto (1.0) 2. Se contiene la continuation parzialmente -> score medio - 3. Altrimenti -> score basso (log della similarità) + 3. Altrimenti -> score basso (0.0) Questo è un workaround poiché non abbiamo i veri logits. """ gen_text = str(generated_text).strip().lower() - target = str(target_continuation).strip().lower() - - # Caso 1: Match esatto - if gen_text == target: - return 0.0 # log(1) = 0, massima likelihood + target = ( + str(target_continuation).strip().lower().replace("(", "").replace(")", "") + ) - # Caso 2: Target è contenuto nel testo generato + # Se il target è contenuto nella risposta (es: "a" è in "la risposta è a") if target in gen_text: - # Parziale match, score moderato - return math.log(0.8) + return 5.0 # Match trovato - # Caso 3: Calcola sovrapposizione di parole - gen_words = set(gen_text.split()) - target_words = set(target.split()) + return -1.0 # Invece di -inf, usa un valore molto basso ma numerico - if not target_words: - return 0.0 - overlap = len(gen_words & target_words) / len(target_words) +@register_model("test_bigbench") +class TestBigBenchWrapper(LM): + def __init__(self, model_args=""): + super().__init__() + self.agent = LLM().create_custom_agent(LLM().get_tools()) - if overlap > 0: - # Maggiore è l'overlap, più alto lo score - return math.log(overlap) - else: - # Nessuna sovrapposizione - return float("-inf") + def _extract_text_from_request(self, request): + """ + Estrae il testo dalla request BigBench. 
+ BigBench passa il prompt completo in doc['inputs']. + """ + # Prova con request.doc - BigBench usa il campo 'inputs' + if hasattr(request, "doc") and isinstance(request.doc, dict): + doc = request.doc + + # BigBench mette il prompt completo in 'inputs' + if "inputs" in doc and doc["inputs"]: + prompt = str(doc["inputs"]).strip() + return prompt + + # Fallback per altri campi + for field in [ + "input", + "question", + "problem", + "text", + "prompt", + "instruction", + ]: + if field in doc and doc[field]: + return str(doc[field]) + + # Fallback: prova con request.arguments + if hasattr(request, "arguments") and request.arguments: + try: + full_prompt = request.arguments[0][0] + if full_prompt and len(str(full_prompt).strip()) > 1: + return full_prompt + except (IndexError, TypeError): + pass + + return str(request) + + def generate_until(self, requests, until=None, **kwargs): + outputs = [] + + if not until: + until = ["\n\n", "\n"] + elif isinstance(until, str): + until = [until] + + for request in requests: + try: + prompt = self._extract_text_from_request(request) + + response = self.agent.invoke( + {"messages": [HumanMessage(content=prompt)]} + ) + response = extract_output(response) + + if not isinstance(response, str): + response = str(response) + + for stop_seq in until: + if stop_seq in response: + response = response.split(stop_seq)[0] + + outputs.append(response.strip()) + + except Exception as e: + print(f"Agent error: {e}") + outputs.append("") + + return outputs + + def loglikelihood(self, requests): + """ + Per agent è meglio NON usare loglikelihood classico. + Usiamo matching diretto. 
+ """ + outputs = [] + + for request in requests: + try: + if hasattr(request, "arguments") and request.arguments: + context, continuation = request.arguments + else: + context, continuation = request + + response = self.agent.invoke( + {"messages": [HumanMessage(content=context)]} + ) + response = str(response).strip().lower() + target = str(continuation).strip().lower() + + # scoring semplice + if response == target: + score = 0.0 + elif target in response: + score = -0.5 + else: + score = -5.0 + + outputs.append((score, False)) + + except Exception as e: + print(f"loglikelihood error: {e}") + outputs.append((float("-inf"), False)) + + return outputs + + def loglikelihood_rolling(self, requests): + return self.loglikelihood(requests) diff --git a/GoT/model/ollama_llm.py b/GoT/model/ollama_llm.py index 37c1c5e..65bfcf2 100644 --- a/GoT/model/ollama_llm.py +++ b/GoT/model/ollama_llm.py @@ -1,3 +1,10 @@ +import inspect +import os +import importlib +import sys +from langchain.tools import BaseTool, tool +from langchain_google_genai import ChatGoogleGenerativeAI + from dotenv import load_dotenv from langchain_openai import ChatOpenAI from langchain_core.messages import SystemMessage @@ -6,17 +13,19 @@ from GoT.tools.math_tool import ( multiply, - sum_four, - sum_three, summing, minus, square_root, + divide, ) +from GoT.tools.craft_tool import craft_tool, install_dependency + load_dotenv() mlflow.set_experiment("marcoraggini-experiment") mlflow.openai.autolog() +mlflow.gemini.autolog() mlflow.langchain.autolog() SYSTEM_PROMPT_GENERAL = """ @@ -25,34 +34,72 @@ """ -class OllamaLLM: +class LLM: def __init__(self): with mlflow.set_active_model(name="ollama-agent-ministral-3"): self.ollamaLLM = ChatOpenAI( base_url="http://localhost:11434/v1", api_key="dummy", model="ministral-3:8b", + temperature=0.5, + reasoning_effort="none", ) - self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) - - self.agent = create_agent( - model=self.ollamaLLM, - tools=[sum_four, summing, 
minus, sum_three, square_root], - system_prompt=self.system_prompt, + # GEMINI LLMs + self.remoteLLMStandard = ChatGoogleGenerativeAI( + model="gemini-2.5-flash", + api_key=os.environ.get("GEMINI_API_KEY"), + temperature=1.0, # Gemini 3.0+ defaults to 1.0 + ) + self.remoteLLMResponseFormat = ChatGoogleGenerativeAI( + model="gemini-2.5-flash", + api_key=os.environ.get("GEMINI_API_KEY"), + temperature=1.0, # Gemini 3.0+ defaults to 1.0 + ) + self.remoteLLMScoreFormat = ChatGoogleGenerativeAI( + model="gemini-2.5-flash", + api_key=os.environ.get("GEMINI_API_KEY"), + temperature=0.7, # Gemini 3.0+ defaults to 1.0 ) + self.remoteLLMs = { + "remote_standard": self.remoteLLMStandard, + "remote_response_format": self.remoteLLMResponseFormat, + "remote_score_format": self.remoteLLMScoreFormat, + } + + self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) + def get_tools(self): - return [sum_four, summing, minus, sum_three, square_root, multiply] + initial_tools = [summing, minus, square_root, multiply, divide] + crafted_tools = self.get_crafted_tools() + return initial_tools + crafted_tools + + def get_craft_tool(self): + return [craft_tool, install_dependency] + + def get_crafted_tools(self) -> list[BaseTool]: + module_name = "GoT.tools.ai_tool" + if module_name in sys.modules: + module = importlib.reload(sys.modules[module_name]) + else: + module = importlib.import_module(module_name) + tools = [] + for name, obj in module.__dict__.items(): + if inspect.isfunction(obj) and obj.__module__ == module.__name__: + tools.append(tool(obj)) # wrap runtime + + return tools def create_custom_agent( self, tools, system_prompt: SystemMessage = SystemMessage(SYSTEM_PROMPT_GENERAL), response_format=None, + type: str = "remote_standard", ): return create_agent( - model=self.ollamaLLM, + model=self.remoteLLMs[type], tools=tools, system_prompt=system_prompt, response_format=response_format, diff --git a/GoT/model/runtime_graph.py b/GoT/model/runtime_graph.py index 221358e..506f3d9 
100644 --- a/GoT/model/runtime_graph.py +++ b/GoT/model/runtime_graph.py @@ -1,7 +1,7 @@ from typing import Dict, List from langgraph.graph import MessagesState from langchain_core.messages import AnyMessage, HumanMessage -from pydantic import BaseModel +from pydantic import BaseModel, Field class RuntimeNode: @@ -33,14 +33,18 @@ def __init__( prompt: str, response: str, score: int, + problem_complexity: int = 0, tool_used: List[str] = [], resolved: bool = False, + need_tool_crafting: bool = False, ): super().__init__(resolved) self.prompt = prompt self.response = response self.score = score self.tool_used = tool_used + self.need_tool_crafting = need_tool_crafting + self.problem_complexity = problem_complexity class ToolNode(RuntimeNode): @@ -89,14 +93,14 @@ def __init__( self.feedback = feedback -class ContextNode(RuntimeNode): +class ReasoningNode(RuntimeNode): def __init__( self, - context: str, + reasoning: str, resolved: bool = False, ): super().__init__(resolved) - self.context = context + self.reasoning = reasoning class ResponseNode(RuntimeNode): @@ -109,28 +113,55 @@ def __init__( self.response = response +class CraftingNode(RuntimeNode): + def __init__( + self, + response: str, + tool_crafted: str = "", + resolved: bool = False, + ): + super().__init__(resolved) + self.response = response + self.tool_crafted = tool_crafted + + class Score(BaseModel): """Rapresents a score for a test node. Attributes: score: int - The score assigned to the test node. description: str - A description or rationale for the assigned score. + need_tool_crafting: bool - Indicates whether the test node requires crafting a new tool to be resolved. """ - score: int - description: str + score: int = Field(..., description="Integer score between 0 and 6 inclusive.") + description: str = Field( + ..., + description="Short justification (1-3 sentences) explaining why the score was assigned.", + ) -class Response(BaseModel): - """Rapresents a response for a tool node. 
+ need_tool_crafting: bool = Field( + ..., + description="Indicates whether the problem requires or it is useful to craft a new tool to be resolved.", + ) - Attributes: - response: str - The synthethic response. - explanation: str - An explanation or rationale for the response. - """ + problem_complexity: int = Field( + ..., + description="Integer between 0 and 5 indicating the complexity of the problem for an AI.", + ) + + +class Response(BaseModel): + """Rapresents a response for a tool node.""" - response: str - explanation: str + response: str = Field( + ..., + description="Final answer to the user request. No meta commentary. If there is a format requirement, follow it.", + ) + explanation: str = Field( + ..., description="Short reasoning explaining how the answer was produced." + ) class RuntimeGraph: @@ -156,10 +187,19 @@ def resolve_node(self, node: RuntimeNode, response: str) -> None: node.response = response node.resolved = True - def call_tool_node(self) -> ToolNode: + def call_tool_node(self) -> ReasoningNode: nodes = list(self.nodes.keys()) - call_nodes = [n for n in nodes if (isinstance(n, ToolNode) and not n.resolved)] - return call_nodes[0] + reasoning_nodes = [ + n for n in nodes if (isinstance(n, ReasoningNode) and not n.resolved) + ] + return reasoning_nodes[0] + + def exist_reasoning_node_available(self) -> bool: + nodes = list(self.nodes.keys()) + reasoning_nodes = [ + n for n in nodes if (isinstance(n, ReasoningNode) and not n.resolved) + ] + return True if reasoning_nodes else False def exist_tool_available(self) -> bool: nodes = list(self.nodes.keys()) @@ -170,6 +210,13 @@ def get_resolved_tools(self): resolved_nodes = [t for t in self.tools_available.keys() if t.resolved is True] return [self.tools_available[n] for n in resolved_nodes] + def is_craftin_node_resolved(self) -> bool: + nodes = list(self.nodes.keys()) + crafting_nodes = [ + n for n in nodes if (isinstance(n, CraftingNode) and n.resolved) + ] + return True if crafting_nodes else 
class ResultEval:
    """One benchmark question together with the model's answer and its grade."""

    def __init__(
        self,
        question: str,
        response: str,
        filtered_answer: str,
        correct_answer: str,
        answer_success: float,
    ):
        # Prompt sent to the model and the raw text it produced.
        self.question, self.response = question, response
        # Answer extracted from the response, and the reference answer.
        self.filtered_answer, self.correct_answer = filtered_answer, correct_answer
        # 1.0 when the answer was graded correct, 0.0 otherwise.
        self.answer_success = answer_success

    @staticmethod
    def create_empty_result(question: str, correct_answer: str):
        """Build a placeholder result for a question that has not been answered."""
        return ResultEval(question, "Error", "", correct_answer, 0.0)

    def to_dict(self):
        """Return the record as a plain dict suitable for JSON serialisation."""
        field_names = (
            "question",
            "response",
            "filtered_answer",
            "correct_answer",
            "answer_success",
        )
        return {field: getattr(self, field) for field in field_names}
D + index_to_letter = {0: "A", 1: "B", 2: "C", 3: "D"} + + for data in dataset: # Vediamo i primi 2 esempi + sample = data + + question = sample["Question"] + correct_answer = sample["Correct Answer"] + + # Creiamo la lista delle opzioni partendo dai dati del sample + choices = [ + correct_answer, + sample["Incorrect Answer 1"], + sample["Incorrect Answer 2"], + sample["Incorrect Answer 3"], + ] + + shuffle(choices) + + correct_idx = choices.index(correct_answer) + correct_letter = index_to_letter[correct_idx] + + prompt = ( + "Answer the following multiple choice question. " + "The last line of your response should be of the following format: " + "‘ANSWER: LETTER’ (without quotes) where LETTER is one of ABCD. " + "Think step by step before answering.\n\n" + f"{question}\n" + f"A) {choices[0]}\n" + f"B) {choices[1]}\n" + f"C) {choices[2]}\n" + f"D) {choices[3]}\n" + "Answer:" + ) + questions.append( + ResultEval.create_empty_result( + question=prompt, correct_answer=correct_letter + ) + ) + + return questions + + +def gpqa_run(questions: list[ResultEval], max_run: int, test: bool) -> list[ResultEval]: + responses = [] + run_counter = 0 + agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) + for q in questions[25:]: + if run_counter >= max_run: + break + prompt = q.question + correct_letter = q.correct_answer + try: + if test: + response = extract_output( + agent.invoke( + {"messages": [HumanMessage(content=prompt)]}, + config={"recursion_limit": 10}, + ) + ) + else: + response = extract_output(call_graph(prompt)) + norm_res = normalize_number(response) + responses.append( + ResultEval( + question=prompt, + response=norm_res, + filtered_answer="", + correct_answer=correct_letter, + answer_success=0.0, + ) + ) + except Exception as e: + print(f"Error processing question: {e}") + responses.append( + ResultEval( + question=prompt, + response="Error", + filtered_answer="", + correct_answer=correct_letter, + answer_success=0.0, + ) + ) + 
def gpqa_eval(responses: "list[ResultEval]"):
    """Grade GPQA multiple-choice responses in place and print the accuracy.

    For every result the final ``ANSWER: <letter>`` marker in the response is
    extracted (case-insensitive, any amount of whitespace after the colon) and
    stored in ``filtered_answer`` (``"N/A"`` when no marker is present). The
    result is counted as correct, and ``answer_success`` set to ``1.0``, when
    that letter equals ``correct_answer``.

    Args:
        responses: Results to grade; each is mutated in place.
    """
    # Compiled once: the same pattern is applied to every response.
    answer_pattern = re.compile(r"ANSWER:\s*([A-D])", re.IGNORECASE)
    correct = 0

    for res in responses:
        letters = answer_pattern.findall(res.response)
        # The prompt asks for the answer on the last line, so when the model
        # mentions several candidates the last marker is the committed one.
        res.filtered_answer = letters[-1].upper() if letters else "N/A"

        # Grade with the extracted letter so that grading and the stored
        # filtered answer can never disagree (the previous substring test was
        # case- and whitespace-sensitive while the extraction was not).
        if res.filtered_answer == str(res.correct_answer).strip().upper():
            correct += 1
            res.answer_success = 1.0

    total = len(responses)
    # An empty run (e.g. max_run=0) must not raise ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Total: {total}")
    print(f"Correct: {correct}")
def _gsm8k_canonical_number(raw: str) -> str:
    """Strip currency/thousands formatting and insignificant trailing zeros."""
    number = raw.replace(",", "").replace("$", "").strip()
    if "." in number:
        # "18.0" -> "18", "2.50" -> "2.5"; integers such as "100" are untouched.
        number = number.rstrip("0").rstrip(".")
    return number


def gsm8k_eval(responses: "list[ResultEval]"):
    """Grade GSM8K responses in place and print the accuracy.

    Both the model response and the reference solution are expected to carry
    the final answer as ``#### <number>``. The last such marker of each text
    is extracted and the two numbers are compared for equality after
    normalisation. ``filtered_answer`` receives the normalised predicted
    number (``"N/A"`` when the response has no marker) and ``answer_success``
    is set to ``1.0`` for a correct answer.

    Args:
        responses: Results to grade; each is mutated in place.
    """
    # A number may carry thousands separators and a fractional part; a bare
    # trailing period (end of sentence) is deliberately not captured.
    marker = re.compile(r"####\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)")
    correct = 0

    for res in responses:
        predicted = marker.findall(res.response)
        expected = marker.findall(res.correct_answer)

        res.filtered_answer = (
            _gsm8k_canonical_number(predicted[-1]) if predicted else "N/A"
        )
        # Fall back to the whole reference when it is already a bare number.
        gold = _gsm8k_canonical_number(
            expected[-1] if expected else res.correct_answer
        )

        # Equality, not substring: the previous `in` test accepted "1" for a
        # reference ending in "#### 18" and also matched numbers that merely
        # appear in the worked solution.
        if predicted and res.filtered_answer == gold:
            correct += 1
            res.answer_success = 1.0

    total = len(responses)
    # An empty run (e.g. max_run=0) must not raise ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Total: {total}")
    print(f"Correct: {correct}")
def _last_boxed_content(text: str) -> "str | None":
    """Return the contents of the last ``\\boxed{...}`` in *text*, or None.

    Braces are matched by depth so nested LaTeX such as
    ``\\boxed{\\frac{1}{2}}`` is returned whole and text following the box is
    not swallowed.
    """
    opener = "\\boxed{"
    start = text.rfind(opener)
    if start == -1:
        return None

    begin = start + len(opener)
    depth = 0
    for idx in range(begin, len(text)):
        char = text[idx]
        if char == "{":
            depth += 1
        elif char == "}":
            if depth == 0:
                return text[begin:idx]
            depth -= 1

    # Unbalanced braces: keep whatever precedes the last closing brace, if any.
    end = text.rfind("}")
    return text[begin:end] if end >= begin else None


def hendrycks_math_eval(responses: "list[ResultEval]"):
    """Grade Hendrycks MATH responses in place and print the accuracy.

    The answer is taken from the last ``\\boxed{...}`` of each response and
    stored in ``filtered_answer`` (``"N/A"`` when there is none). It is
    accepted when it equals the normalised reference, when both contain the
    same non-empty multiset of integers, or when the two are symbolically
    equal. ``answer_success`` is set to ``1.0`` for an accepted answer.

    Args:
        responses: Results to grade; each is mutated in place.
    """
    correct = 0

    for res in responses:
        boxed = _last_boxed_content(res.response)
        norm_res = boxed.strip() if boxed is not None else "N/A"
        norm_correct = normalize_number(res.correct_answer)
        res.filtered_answer = norm_res

        # A missing or empty answer can never be correct. Without this guard
        # "N/A" matched an unparseable reference and "" matched everything.
        if boxed is None or not norm_res:
            continue

        answer_numbers = normalize_list(norm_res)
        if (
            norm_res == norm_correct
            # Only compare number lists when there are numbers to compare;
            # two answers without digits would otherwise be "equal" ([] == []).
            or (answer_numbers and answer_numbers == normalize_list(norm_correct))
            or symbolic_equal(norm_res, norm_correct)
        ):
            correct += 1
            res.answer_success = 1.0

    total = len(responses)
    # An empty run (e.g. max_run=0) must not raise ZeroDivisionError.
    accuracy = correct / total * 100 if total else 0.0
    print(f"Accuracy: {accuracy:.2f}%")
    print(f"Total: {total}")
    print(f"Correct: {correct}")
def defining_and_parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Options:
        --benchmark  Benchmark to run: ``gsm8k``, ``gpqa`` or
                     ``hendrycks_math`` (required).
        --mode       ``graph`` for the graph model, ``standard`` for the plain
                     agent (required).
        --max_run    Maximum number of questions to run (default ``1``).
        --type       Subject subset, used only by ``hendrycks_math``
                     (default ``algebra``).

    Returns:
        The parsed ``argparse.Namespace``.

    When called without any argument the help text is printed to stderr and
    the process exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Run the GoT model on a benchmark or a custom problem."
    )
    parser.add_argument(
        "--benchmark",
        required=True,
        type=str,
        choices=["gsm8k", "gpqa", "hendrycks_math"],
        help="The benchmark to run the model on.",
    )
    parser.add_argument(
        "--mode",
        required=True,
        type=str,
        choices=["graph", "standard"],
        help="Whether to run the standard model or the graph model.",
    )
    parser.add_argument(
        "--max_run",
        type=int,
        default=1,
        help="The maximum number of runs for the benchmark.",
    )
    parser.add_argument(
        "--type",
        type=str,
        default="algebra",
        # These must be dataset config names, because the value is passed
        # straight to load_dataset("EleutherAI/hendrycks_math", type).
        # NOTE(review): "prealgebra" replaces the former "statistics" entry,
        # which does not appear to be a config of that dataset - confirm
        # against the dataset card.
        choices=[
            "algebra",
            "counting_and_probability",
            "geometry",
            "intermediate_algebra",
            "number_theory",
            "prealgebra",
            "precalculus",
        ],
        help="The type of math problems to run, only for hendrycks_math.",
    )

    # Show the full help instead of argparse's terse "required" error when the
    # program is started with no arguments at all.
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    return parser.parse_args()
def normalize_number(num_str: str) -> str:
    """Normalise a numeric answer string for comparison.

    Dollar signs, commas and asterisks are removed everywhere, surrounding
    whitespace is trimmed, and a value that ends in ``.0`` and parses as an
    integral float is rewritten as a plain integer (``"12.0"`` -> ``"12"``).
    Any other text is returned as cleaned, unchanged otherwise.
    """
    cleaned = num_str.translate(str.maketrans("", "", "$,*")).strip()

    # Only a literal ".0" suffix is collapsed; "3.50" or "7.00" stay as they are.
    if not cleaned.endswith(".0"):
        return cleaned

    try:
        value = float(cleaned)
    except ValueError:
        # Not a number after all (e.g. "v1.0 beta.0"): leave it alone.
        return cleaned

    return str(int(value)) if value.is_integer() else cleaned
def print_benchmark_result(results: dict, task_name: str, filter: str) -> None:
    """Print total / correct / wrong counts for one task and one filter.

    Args:
        results: Evaluation output holding per-task sample lists under
            ``results["samples"][task_name]``.
        task_name: Task whose samples are summarised.
        filter: Substring that a sample's ``"filter"`` value must contain for
            the sample to be counted.
    """
    n_total = 0
    n_correct = 0

    for sample in results["samples"][task_name]:
        # Samples without a "filter" entry are treated as having an empty one.
        if filter not in sample.get("filter", ""):
            continue
        n_total += 1
        if sample.get("exact_match", 0) == 1.0:
            n_correct += 1

    n_wrong = n_total - n_correct

    print(f"Total: {n_total}")
    print(f"Correct answers (filter={filter}): {n_correct}")
    print(f"Wrong answers (filter={filter}): {n_wrong}")
@tool
def install_dependency(package_name: str) -> str:
    """Install a Python package using poetry."""
    # NOTE: the docstring above is also the tool description handed to the
    # model, so it is kept short; implementation notes live in comments.
    #
    # Returns a human-readable status string and never raises: a success
    # message when `poetry add` exits with status 0, otherwise an error
    # message carrying poetry's output.
    import subprocess  # local import: only this tool needs it

    requirement = package_name.strip()
    # The value comes from the model, so treat it as untrusted. A leading "-"
    # would be parsed by poetry as an option rather than a package name.
    if not requirement or requirement.startswith("-"):
        return f"Invalid package name: {package_name!r}"

    try:
        # Argument list, no shell: metacharacters in the name (";", "&&", "`")
        # are passed to poetry verbatim instead of being executed.
        completed = subprocess.run(
            ["poetry", "add", requirement],
            capture_output=True,
            text=True,
            check=False,
        )
    except Exception as e:
        # Typically poetry is not installed / not on PATH.
        return str(e)

    if completed.returncode != 0:
        # Report the failure to the agent instead of claiming success.
        details = (completed.stderr or completed.stdout).strip()
        return f"Failed to install {requirement}: {details}"

    return f"Package {requirement} installed successfully."
+ except Exception as e: + print(e) + return str(e) diff --git a/GoT/tools/math_tool.py b/GoT/tools/math_tool.py index 00ddd74..7a97d19 100644 --- a/GoT/tools/math_tool.py +++ b/GoT/tools/math_tool.py @@ -3,61 +3,49 @@ @tool -def summing(x: int, y: int) -> int: - """sum of two integer +def summing(x: float, y: float) -> float: + """sum of two float numbers Arguments: - x(int): first number - y(int): second number + x(float): first number + y(float): second number """ return x + y @tool -def sum_three(x: int, y: int, z: int): - """sum of three integer +def minus(x: float, y: float) -> float: + """minus of two float numbers Arguments: - x(int): first number - y(int): second number - z(int): third number + x(float): first number + y(float): second number """ - return x + y + z - - -@tool -def sum_four(x, y, z, a): - """sum of four integer - - Arguments: - x(int): first number - y(int): second number - z(int): third number - a(int): fourth number - """ - return x + y + z + a + return x - y @tool -def minus(x: int, y: int) -> int: - """minus of two integer +def multiply(x: float, y: float) -> float: + """multiply of two float numbers Arguments: - x(int): first number - y(int): second number + x(float): first number + y(float): second number """ - return x - y + return x * y @tool -def multiply(x: int, y: int) -> int: - """multiply of two integer +def divide(x: float, y: float) -> float: + """divide of two float numbers Arguments: - x(int): first number - y(int): second number + x(float): first number + y(float): second number """ - return x * y + if y == 0: + raise ValueError("Cannot divide by zero") + return x / y @tool diff --git a/GoT/tools/runtime_graph_tool.py b/GoT/tools/runtime_graph_tool.py new file mode 100644 index 0000000..77df52a --- /dev/null +++ b/GoT/tools/runtime_graph_tool.py @@ -0,0 +1,94 @@ +from langchain.messages import HumanMessage, SystemMessage +from langchain.tools import tool + +from GoT.model.ollama_llm import LLM +from GoT.model.runtime_graph 
import ReasoningNode, RuntimeGraph +from GoT.model.utils.utils import parse_response + +MAX_INTERACTIONS = 10 + + +@tool +def divide_thought( + first_part: str, + second_part: str, + first_context: str, + second_context: str, + reasoning_type: str = "pure_reasoning", +) -> str: + """ + This is a tool to divide the thought process into smaller steps. + HOW TO USE THIS TOOL: + - Call it when you think the problem is complex. + - The two parts must be as independent as possible. + Arguments: + - first_part: the first part of the thought process + - second_part: the second part of the thought process + - first_context: the context related to the first part + - second_context: the context related to the second part + - reasoning_type: the type of reasoning to use for each part, it can be "pure_reasoning" or "tool_use" + """ + tool_agent = LLM().create_custom_agent( + LLM().get_tools(), + SystemMessage( + "You are an assistant specialized in tools. Your goal is to resolve the problem with " + " the tool that the user indicates to you. You should to use the tool that the assistant indicates to you." + "Do not write natural language outside the function. " + "If you fail to respect the format, the evaluation will fail." + ), + ) + reasoning_agent = LLM().create_custom_agent( + [], + SystemMessage( + "You are an assistant specialized in reasoning. " + "Your goal is to resolve the problem with reasoning. You should to reason step by step and write all your reasoning. " + "If the problem is too complex, you can divide it into smaller parts." 
+ ), + ) + runtime_graph = RuntimeGraph() + n1 = ReasoningNode("") + n2 = ReasoningNode("") + runtime_graph.add_node(n1) + runtime_graph.add_node(n2) + runtime_graph.add_edge(runtime_graph.temp_node, n1) + runtime_graph.add_edge(runtime_graph.temp_node, n2) + msg1 = [ + HumanMessage( + "Reason aboout this problem: " + first_part + "\nContext: " + first_context + ) + ] + msg2 = [ + HumanMessage( + "Reason aboout this problem: " + + second_part + + "\nContext: " + + second_context + ) + ] + if reasoning_type == "pure_reasoning": + res1 = parse_response( + reasoning_agent.invoke( + {"messages": msg1}, config={"recursion_limit": MAX_INTERACTIONS} + ) + ) + res2 = parse_response( + reasoning_agent.invoke( + {"messages": msg2}, config={"recursion_limit": MAX_INTERACTIONS} + ) + ) + else: + res1 = parse_response( + tool_agent.invoke( + {"messages": msg1}, config={"recursion_limit": MAX_INTERACTIONS} + ) + ) + res2 = parse_response( + tool_agent.invoke( + {"messages": msg2}, config={"recursion_limit": MAX_INTERACTIONS} + ) + ) + runtime_graph.resolve_node(n1, res1) + runtime_graph.resolve_node(n2, res2) + + result = f"First part: {res1}\nSecond part: {res2}" + return result diff --git a/poetry.lock b/poetry.lock index 231f1d6..e6e3dce 100644 --- a/poetry.lock +++ b/poetry.lock @@ -221,6 +221,17 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.11.0" +description = "ANTLR 4.11.0 runtime for Python 3" +optional = false +python-versions = "*" +files = [ + {file = "antlr4-python3-runtime-4.11.0.tar.gz", hash = "sha256:6a46d4f6033c8d7e53c8975cee5d83b3ba1b48157f6beb09a4965062dacfe2e0"}, + {file = "antlr4_python3_runtime-4.11.0-py3-none-any.whl", hash = "sha256:f523f91387283045f129c343b363a8dda3a07eea62721b8eda95f2d8b817b656"}, +] + [[package]] name = "anyio" version = "4.12.1" @@ -1290,6 +1301,17 @@ files = [ {file = 
"filelock-3.24.2.tar.gz", hash = "sha256:c22803117490f156e59fafce621f0550a7a853e2bbf4f87f112b11d469b6c81b"}, ] +[[package]] +name = "filetype" +version = "1.2.0" +description = "Infer file type and MIME type of any file/buffer. No external dependencies." +optional = false +python-versions = "*" +files = [ + {file = "filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25"}, + {file = "filetype-1.2.0.tar.gz", hash = "sha256:66b56cd6474bf41d8c54660347d37afcc3f7d1970648de365c102ef77548aadb"}, +] + [[package]] name = "flask" version = "3.1.2" @@ -1627,6 +1649,7 @@ files = [ [package.dependencies] cryptography = ">=38.0.3" pyasn1-modules = ">=0.2.1" +requests = {version = ">=2.20.0,<3.0.0", optional = true, markers = "extra == \"requests\""} rsa = ">=3.1.4,<5" [package.extras] @@ -1640,6 +1663,33 @@ requests = ["requests (>=2.20.0,<3.0.0)"] testing = ["aiohttp (<3.10.0)", "aiohttp (>=3.6.2,<4.0.0)", "aioresponses", "flask", "freezegun", "grpcio", "oauth2client", "packaging", "pyjwt (>=2.0)", "pyopenssl (<24.3.0)", "pyopenssl (>=20.0.0)", "pytest", "pytest-asyncio", "pytest-cov", "pytest-localserver", "pyu2f (>=0.1.5)", "requests (>=2.20.0,<3.0.0)", "responses", "urllib3"] urllib3 = ["packaging", "urllib3"] +[[package]] +name = "google-genai" +version = "1.66.0" +description = "GenAI Python SDK" +optional = false +python-versions = ">=3.10" +files = [ + {file = "google_genai-1.66.0-py3-none-any.whl", hash = "sha256:7f127a39cf695277104ce4091bb26e417c59bb46e952ff3699c3a982d9c474ee"}, + {file = "google_genai-1.66.0.tar.gz", hash = "sha256:ffc01647b65046bca6387320057aa51db0ad64bcc72c8e3e914062acfa5f7c49"}, +] + +[package.dependencies] +anyio = ">=4.8.0,<5.0.0" +distro = ">=1.7.0,<2" +google-auth = {version = ">=2.47.0,<3.0.0", extras = ["requests"]} +httpx = ">=0.28.1,<1.0.0" +pydantic = ">=2.9.0,<3.0.0" +requests = ">=2.28.1,<3.0.0" +sniffio = "*" +tenacity = ">=8.2.3,<9.2.0" +typing-extensions = 
">=4.11.0,<5.0.0" +websockets = ">=13.0.0,<17.0" + +[package.extras] +aiohttp = ["aiohttp (>=3.10.11,<4.0.0)"] +local-tokenizer = ["protobuf", "sentencepiece (>=0.2.0)"] + [[package]] name = "graphene" version = "3.4.3" @@ -2339,18 +2389,20 @@ files = [ [[package]] name = "langchain" -version = "1.2.10" +version = "1.2.11" description = "Building applications with LLMs through composability" optional = false python-versions = "<4.0.0,>=3.10.0" files = [ - {file = "langchain-1.2.10-py3-none-any.whl", hash = "sha256:e07a377204451fffaed88276b8193e894893b1003e25c5bca6539288ccca3698"}, - {file = "langchain-1.2.10.tar.gz", hash = "sha256:bdcd7218d9c79a413cf15e106e4eb94408ac0963df9333ccd095b9ed43bf3be7"}, + {file = "langchain-1.2.11-py3-none-any.whl", hash = "sha256:ccc5d23e2568efa6e3cb2dde268a267d7f090bdad47d7e1ee5f0c9769e8cb7b9"}, + {file = "langchain-1.2.11.tar.gz", hash = "sha256:1b0f680de88c178898a69f9814025729110fc68365b2c33cd2ba978114d5fc0a"}, ] [package.dependencies] langchain-core = ">=1.2.10,<2.0.0" -langgraph = ">=1.0.8,<1.1.0" +langchain-google-genai = {version = "*", optional = true, markers = "extra == \"google-genai\""} +langchain-openai = {version = "*", optional = true, markers = "extra == \"openai\""} +langgraph = ">=1.1.0,<1.3.0" pydantic = ">=2.7.4,<3.0.0" [package.extras] @@ -2392,6 +2444,23 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<10.0.0" typing-extensions = ">=4.7.0,<5.0.0" uuid-utils = ">=0.12.0,<1.0" +[[package]] +name = "langchain-google-genai" +version = "4.2.1" +description = "An integration package connecting Google's genai package and LangChain" +optional = false +python-versions = "<4.0.0,>=3.10.0" +files = [ + {file = "langchain_google_genai-4.2.1-py3-none-any.whl", hash = "sha256:a7735289cf94ca3a684d830e09196aac8f6e75e647e3a0a1c3c9dc534ceb985e"}, + {file = "langchain_google_genai-4.2.1.tar.gz", hash = "sha256:7f44487a0337535897e3bba9a1d6605d722629e034f757ffa8755af0aa85daa8"}, +] + +[package.dependencies] +filetype = ">=1.2.0,<2.0.0" 
+google-genai = ">=1.56.0,<2.0.0" +langchain-core = ">=1.2.5,<2.0.0" +pydantic = ">=2.0.0,<3.0.0" + [[package]] name = "langchain-openai" version = "1.1.9" @@ -2410,19 +2479,19 @@ tiktoken = ">=0.7.0,<1.0.0" [[package]] name = "langgraph" -version = "1.0.8" +version = "1.1.0" description = "Building stateful, multi-actor applications with LLMs" optional = false python-versions = ">=3.10" files = [ - {file = "langgraph-1.0.8-py3-none-any.whl", hash = "sha256:da737177c024caad7e5262642bece4f54edf4cba2c905a1d1338963f41cf0904"}, - {file = "langgraph-1.0.8.tar.gz", hash = "sha256:2630fc578846995114fd659f8b14df9eff5a4e78c49413f67718725e88ceb544"}, + {file = "langgraph-1.1.0-py3-none-any.whl", hash = "sha256:7d29e01312340c9c7c09eb0178db5edd0514c3f35929fa5b32a247fcd9a9cd65"}, + {file = "langgraph-1.1.0.tar.gz", hash = "sha256:2decaef5d6716166dc5c13e0fdf65637e5a1837ef4c94fd82b6bcf2115cb5c78"}, ] [package.dependencies] langchain-core = ">=0.1" langgraph-checkpoint = ">=2.1.0,<5.0.0" -langgraph-prebuilt = ">=1.0.7,<1.1.0" +langgraph-prebuilt = ">=1.0.8,<1.1.0" langgraph-sdk = ">=0.3.0,<0.4.0" pydantic = ">=2.7.4" xxhash = ">=3.5.0" @@ -2444,13 +2513,13 @@ ormsgpack = ">=1.12.0" [[package]] name = "langgraph-prebuilt" -version = "1.0.7" +version = "1.0.8" description = "Library with high-level APIs for creating and executing LangGraph agents and tools." 
optional = false python-versions = ">=3.10" files = [ - {file = "langgraph_prebuilt-1.0.7-py3-none-any.whl", hash = "sha256:e14923516504405bb5edc3977085bc9622c35476b50c1808544490e13871fe7c"}, - {file = "langgraph_prebuilt-1.0.7.tar.gz", hash = "sha256:38e097e06de810de4d0e028ffc0e432bb56d1fb417620fb1dfdc76c5e03e4bf9"}, + {file = "langgraph_prebuilt-1.0.8-py3-none-any.whl", hash = "sha256:d16a731e591ba4470f3e313a319c7eee7dbc40895bcf15c821f985a3522a7ce0"}, + {file = "langgraph_prebuilt-1.0.8.tar.gz", hash = "sha256:0cd3cf5473ced8a6cd687cc5294e08d3de57529d8dd14fdc6ae4899549efcf69"}, ] [package.dependencies] @@ -2502,6 +2571,30 @@ otel = ["opentelemetry-api (>=1.30.0)", "opentelemetry-exporter-otlp-proto-http pytest = ["pytest (>=7.0.0)", "rich (>=13.9.4)", "vcrpy (>=7.0.0)"] vcr = ["vcrpy (>=7.0.0)"] +[[package]] +name = "latex2sympy2-extended" +version = "1.11.0" +description = "Convert LaTeX math to SymPy expressions" +optional = false +python-versions = ">=3.10" +files = [ + {file = "latex2sympy2_extended-1.11.0-py3-none-any.whl", hash = "sha256:aebb77d52ce269e25028e4bea89ddb14d242ba36bcf7b636496fb5fd9728d234"}, + {file = "latex2sympy2_extended-1.11.0.tar.gz", hash = "sha256:9695657c81b50abba2636638638618db59f4663ed2a4a12d62cef74a40e28fec"}, +] + +[package.dependencies] +antlr4-python3-runtime = [ + {version = ">=4.9.3,<=4.13.2"}, + {version = "4.11.0", optional = true, markers = "extra == \"antlr4-11-0\""}, +] +sympy = "*" + +[package.extras] +antlr4-11-0 = ["antlr4-python3-runtime (==4.11.0)"] +antlr4-13-2 = ["antlr4-python3-runtime (==4.13.2)"] +antlr4-9-3 = ["antlr4-python3-runtime (==4.9.3)"] +dev = ["pytest"] + [[package]] name = "librt" version = "0.8.0" @@ -2613,11 +2706,13 @@ files = [ ] [package.dependencies] +antlr4-python3-runtime = {version = "4.11", optional = true, markers = "extra == \"math\""} datasets = ">=2.16.0" dill = "*" evaluate = ">=0.4.0" jinja2 = "*" jsonlines = "*" +math_verify = {version = "*", extras = ["antlr4-11-0"], optional = true, 
markers = "extra == \"math\""} more_itertools = "*" numpy = "*" pytablewriter = "*" @@ -2625,6 +2720,7 @@ rouge-score = ">=0.0.4" sacrebleu = ">=1.5.0" scikit-learn = ">=0.24.1" sqlitedict = "*" +sympy = {version = ">=1.12", optional = true, markers = "extra == \"math\""} typing_extensions = "*" word2number = "*" zstandard = "*" @@ -2951,6 +3047,32 @@ files = [ {file = "markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698"}, ] +[[package]] +name = "math-verify" +version = "0.9.0" +description = "HuggingFace library for verifying mathematical answers" +optional = false +python-versions = ">=3.10" +files = [ + {file = "math_verify-0.9.0-py3-none-any.whl", hash = "sha256:3703e7c4885354027fa84409d762a596a2906d1fd4deb78361876bd905a76194"}, + {file = "math_verify-0.9.0.tar.gz", hash = "sha256:45ac6c61344ba056b9e99a660a4bc8d044ed408f730aed68c60435aa5eec4645"}, +] + +[package.dependencies] +latex2sympy2_extended = [ + {version = "1.11.0"}, + {version = "*", extras = ["antlr4-11-0"], optional = true, markers = "extra == \"antlr4-11-0\""}, +] + +[package.extras] +antlr4-11-0 = ["latex2sympy2_extended[antlr4-11-0]"] +antlr4-13-2 = ["latex2sympy2_extended[antlr4-13-2]"] +antlr4-9-3 = ["latex2sympy2_extended[antlr4-9-3]"] +dev = ["math-verify[format]", "math-verify[test]"] +format = ["ruff"] +inference = ["lighteval[math]"] +test = ["pytest"] + [[package]] name = "matplotlib" version = "3.10.8" @@ -3181,6 +3303,23 @@ files = [ {file = "more_itertools-10.8.0.tar.gz", hash = "sha256:f638ddf8a1a0d134181275fb5d58b086ead7c6a72429ad725c67503f13ba30bd"}, ] +[[package]] +name = "mpmath" +version = "1.3.0" +description = "Python library for arbitrary-precision floating-point arithmetic" +optional = false +python-versions = "*" +files = [ + {file = "mpmath-1.3.0-py3-none-any.whl", hash = "sha256:a0b2b9fe80bbcd81a6647ff13108738cfb482d481d826cc0e02f5b35e5c88d2c"}, + {file = "mpmath-1.3.0.tar.gz", hash = 
"sha256:7a28eb2a9774d00c7bc92411c19a89209d5da7c4c9a9e227be8330a23a25b91f"}, +] + +[package.extras] +develop = ["codecov", "pycodestyle", "pytest (>=4.6)", "pytest-cov", "wheel"] +docs = ["sphinx"] +gmpy = ["gmpy2 (>=2.1.0a4)"] +tests = ["pytest (>=4.6)"] + [[package]] name = "msgpack" version = "1.1.2" @@ -4787,13 +4926,13 @@ six = ">=1.5" [[package]] name = "python-dotenv" -version = "1.2.1" +version = "1.2.2" description = "Read key-value pairs from a .env file and set them as environment variables" optional = false -python-versions = ">=3.9" +python-versions = ">=3.10" files = [ - {file = "python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61"}, - {file = "python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6"}, + {file = "python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a"}, + {file = "python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3"}, ] [package.extras] @@ -5677,6 +5816,23 @@ typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\"" [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] +[[package]] +name = "sympy" +version = "1.14.0" +description = "Computer algebra system (CAS) in Python" +optional = false +python-versions = ">=3.9" +files = [ + {file = "sympy-1.14.0-py3-none-any.whl", hash = "sha256:e091cc3e99d2141a0ba2847328f5479b05d94a6635cb96148ccb3f34671bd8f5"}, + {file = "sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517"}, +] + +[package.dependencies] +mpmath = ">=1.1.0,<1.4" + +[package.extras] +dev = ["hypothesis (>=6.70.0)", "pytest (>=7.1.0)"] + [[package]] name = "tabledata" version = "1.3.4" @@ -6124,6 +6280,76 @@ files = [ {file = "wcwidth-0.6.0.tar.gz", hash = 
"sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159"}, ] +[[package]] +name = "websockets" +version = "16.0" +description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)" +optional = false +python-versions = ">=3.10" +files = [ + {file = "websockets-16.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:04cdd5d2d1dacbad0a7bf36ccbcd3ccd5a30ee188f2560b7a62a30d14107b31a"}, + {file = "websockets-16.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8ff32bb86522a9e5e31439a58addbb0166f0204d64066fb955265c4e214160f0"}, + {file = "websockets-16.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:583b7c42688636f930688d712885cf1531326ee05effd982028212ccc13e5957"}, + {file = "websockets-16.0-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7d837379b647c0c4c2355c2499723f82f1635fd2c26510e1f587d89bc2199e72"}, + {file = "websockets-16.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:df57afc692e517a85e65b72e165356ed1df12386ecb879ad5693be08fac65dde"}, + {file = "websockets-16.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2b9f1e0d69bc60a4a87349d50c09a037a2607918746f07de04df9e43252c77a3"}, + {file = "websockets-16.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:335c23addf3d5e6a8633f9f8eda77efad001671e80b95c491dd0924587ece0b3"}, + {file = "websockets-16.0-cp310-cp310-win32.whl", hash = "sha256:37b31c1623c6605e4c00d466c9d633f9b812ea430c11c8a278774a1fde1acfa9"}, + {file = "websockets-16.0-cp310-cp310-win_amd64.whl", hash = "sha256:8e1dab317b6e77424356e11e99a432b7cb2f3ec8c5ab4dabbcee6add48f72b35"}, + {file = "websockets-16.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:31a52addea25187bde0797a97d6fc3d2f92b6f72a9370792d65a6e84615ac8a8"}, + {file = "websockets-16.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:417b28978cdccab24f46400586d128366313e8a96312e4b9362a4af504f3bbad"}, + {file = 
"websockets-16.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:af80d74d4edfa3cb9ed973a0a5ba2b2a549371f8a741e0800cb07becdd20f23d"}, + {file = "websockets-16.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:08d7af67b64d29823fed316505a89b86705f2b7981c07848fb5e3ea3020c1abe"}, + {file = "websockets-16.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7be95cfb0a4dae143eaed2bcba8ac23f4892d8971311f1b06f3c6b78952ee70b"}, + {file = "websockets-16.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d6297ce39ce5c2e6feb13c1a996a2ded3b6832155fcfc920265c76f24c7cceb5"}, + {file = "websockets-16.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1c1b30e4f497b0b354057f3467f56244c603a79c0d1dafce1d16c283c25f6e64"}, + {file = "websockets-16.0-cp311-cp311-win32.whl", hash = "sha256:5f451484aeb5cafee1ccf789b1b66f535409d038c56966d6101740c1614b86c6"}, + {file = "websockets-16.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7f0659570eefb578dacde98e24fb60af35350193e4f56e11190787bee77dac"}, + {file = "websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00"}, + {file = "websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79"}, + {file = "websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39"}, + {file = "websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c"}, + {file = "websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f"}, + {file = "websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = 
"sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1"}, + {file = "websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2"}, + {file = "websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89"}, + {file = "websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea"}, + {file = "websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9"}, + {file = "websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230"}, + {file = "websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c"}, + {file = "websockets-16.0-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:95724e638f0f9c350bb1c2b0a7ad0e83d9cc0c9259f3ea94e40d7b02a2179ae5"}, + {file = "websockets-16.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0204dc62a89dc9d50d682412c10b3542d748260d743500a85c13cd1ee4bde82"}, + {file = "websockets-16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:52ac480f44d32970d66763115edea932f1c5b1312de36df06d6b219f6741eed8"}, + {file = "websockets-16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6e5a82b677f8f6f59e8dfc34ec06ca6b5b48bc4fcda346acd093694cc2c24d8f"}, + {file = "websockets-16.0-cp313-cp313-win32.whl", hash = "sha256:abf050a199613f64c886ea10f38b47770a65154dc37181bfaff70c160f45315a"}, + {file = "websockets-16.0-cp313-cp313-win_amd64.whl", hash = "sha256:3425ac5cf448801335d6fdc7ae1eb22072055417a96cc6b31b3861f455fbc156"}, + {file = "websockets-16.0-cp314-cp314-macosx_10_15_universal2.whl", hash = 
"sha256:8cc451a50f2aee53042ac52d2d053d08bf89bcb31ae799cb4487587661c038a0"}, + {file = "websockets-16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:daa3b6ff70a9241cf6c7fc9e949d41232d9d7d26fd3522b1ad2b4d62487e9904"}, + {file = "websockets-16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:fd3cb4adb94a2a6e2b7c0d8d05cb94e6f1c81a0cf9dc2694fb65c7e8d94c42e4"}, + {file = "websockets-16.0-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:781caf5e8eee67f663126490c2f96f40906594cb86b408a703630f95550a8c3e"}, + {file = "websockets-16.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:caab51a72c51973ca21fa8a18bd8165e1a0183f1ac7066a182ff27107b71e1a4"}, + {file = "websockets-16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:19c4dc84098e523fd63711e563077d39e90ec6702aff4b5d9e344a60cb3c0cb1"}, + {file = "websockets-16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:a5e18a238a2b2249c9a9235466b90e96ae4795672598a58772dd806edc7ac6d3"}, + {file = "websockets-16.0-cp314-cp314-win32.whl", hash = "sha256:a069d734c4a043182729edd3e9f247c3b2a4035415a9172fd0f1b71658a320a8"}, + {file = "websockets-16.0-cp314-cp314-win_amd64.whl", hash = "sha256:c0ee0e63f23914732c6d7e0cce24915c48f3f1512ec1d079ed01fc629dab269d"}, + {file = "websockets-16.0-cp314-cp314t-macosx_10_15_universal2.whl", hash = "sha256:a35539cacc3febb22b8f4d4a99cc79b104226a756aa7400adc722e83b0d03244"}, + {file = "websockets-16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:b784ca5de850f4ce93ec85d3269d24d4c82f22b7212023c974c401d4980ebc5e"}, + {file = "websockets-16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:569d01a4e7fba956c5ae4fc988f0d4e187900f5497ce46339c996dbf24f17641"}, + {file = "websockets-16.0-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50f23cdd8343b984957e4077839841146f67a3d31ab0d00e6b824e74c5b2f6e8"}, + {file = 
"websockets-16.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:152284a83a00c59b759697b7f9e9cddf4e3c7861dd0d964b472b70f78f89e80e"}, + {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bc59589ab64b0022385f429b94697348a6a234e8ce22544e3681b2e9331b5944"}, + {file = "websockets-16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:32da954ffa2814258030e5a57bc73a3635463238e797c7375dc8091327434206"}, + {file = "websockets-16.0-cp314-cp314t-win32.whl", hash = "sha256:5a4b4cc550cb665dd8a47f868c8d04c8230f857363ad3c9caf7a0c3bf8c61ca6"}, + {file = "websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd"}, + {file = "websockets-16.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:349f83cd6c9a415428ee1005cadb5c2c56f4389bc06a9af16103c3bc3dcc8b7d"}, + {file = "websockets-16.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4a1aba3340a8dca8db6eb5a7986157f52eb9e436b74813764241981ca4888f03"}, + {file = "websockets-16.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f4a32d1bd841d4bcbffdcb3d2ce50c09c3909fbead375ab28d0181af89fd04da"}, + {file = "websockets-16.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0298d07ee155e2e9fda5be8a9042200dd2e3bb0b8a38482156576f863a9d457c"}, + {file = "websockets-16.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:a653aea902e0324b52f1613332ddf50b00c06fdaf7e92624fbf8c77c78fa5767"}, + {file = "websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec"}, + {file = "websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5"}, +] + [[package]] name = "werkzeug" version = "3.1.5" @@ -6642,4 +6868,4 @@ cffi = ["cffi (>=1.17,<2.0)", "cffi (>=2.0.0b)"] [metadata] 
lock-version = "2.0" python-versions = ">= 3.10.0 < 3.14.0" -content-hash = "bf1a08632da81c1ddbb92096c0442b3d8f8893d83ac33a9e8f120c9950dcd145" +content-hash = "c7c8e8591227891fa161988b58a7a8e708d8769ce133faf007103f95df4a0ef4" diff --git a/pyproject.toml b/pyproject.toml index 022ecf8..95c44ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,12 +12,12 @@ readme = "README.md" [tool.poetry.dependencies] # update the Python versions in .github/workflows/check.yml if you change this python = ">= 3.10.0 < 3.14.0" -langchain = "^1.2.4" +langchain = {extras = ["google-genai", "openai"], version = "^1.2.11"} langchain-openai = "^1.1.7" langgraph = "^1.0.6" dotenv = "^0.9.9" mlflow = "^3.9.0" -lm-eval = "^0.4.10" +lm-eval = {extras = ["math"], version = "^0.4.11"} boto3 = "^1.42.51" [tool.poetry.group.dev.dependencies] @@ -49,5 +49,5 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [[tool.mypy.overrides]] -module = ["lm_eval.*"] +module = ["lm_eval.*", "datasets.*", "sympy.*"] follow_untyped_imports = true