diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 555263c..77a591e 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -60,7 +60,7 @@ jobs: npm install npx semantic-release env: - # PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }} RELEASE_TEST_PYPI: ${{ github.event.repository.is_template || contains(github.repository, 'template') }} # dry run if not on main/master branch, or if initial commit diff --git a/GoT/__init__.py b/GoT/__init__.py index c9d8fc7..8a62a54 100644 --- a/GoT/__init__.py +++ b/GoT/__init__.py @@ -3,10 +3,9 @@ from dotenv import load_dotenv from lm_eval import evaluator, tasks -from GoT.model.graph_model import call_graph -from GoT.model.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper -from GoT.model.utils.parse_args import call_benchmark, defining_and_parse_args -from GoT.model.utils.utils import ( +from GoT.experiments.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper +from GoT.cli.parse_args import call_benchmark, defining_and_parse_args +from GoT.utils.utils import ( print_benchmark_result, print_benchmark_result_loglikehood, ) @@ -62,14 +61,10 @@ def lm_eval_graph_benchmark(): print_benchmark_result_loglikehood(results, task_name, filter_val="none") -def custom_test(): - call_graph("Solve this integral ∫x2⋅ex2dx") - - def main(): - # It could be changed with custom_test() to test a custom problem instead of the benchmark args = defining_and_parse_args() call_benchmark(args) + # download_mlflow_traces(50) # let this be the last line of this file diff --git a/GoT/tools/ai_tool.py b/GoT/agent_tools/ai_tool.py similarity index 100% rename from GoT/tools/ai_tool.py rename to GoT/agent_tools/ai_tool.py diff --git a/GoT/tools/craft_tool.py b/GoT/agent_tools/craft_tool.py similarity index 86% rename from GoT/tools/craft_tool.py rename to GoT/agent_tools/craft_tool.py index 8db13dd..d2569b4 100644 --- a/GoT/tools/craft_tool.py +++ b/GoT/agent_tools/craft_tool.py @@ -56,6 +56,8 @@ def craft_tool(tool_function: str) -> str: """Save the function definition provided by the LLM as a tool that can be used by other agents. The function should be defined as a python function. The function should be general and reusable, and should not be specific to the current problem. + The function must not use tuple as args type. + The function must be defined as gemini api format, with type annotations for all arguments and return type. The function should be defined in a way that it can be imported and used by other agents.""" def sanitize_input(query: str) -> str: @@ -81,6 +83,13 @@ def sanitize_input(query: str) -> str: func = functions[0] + try: + docstring = ast.get_docstring(func) + if not docstring or not docstring.strip(): + return "Error: missing docstring. A description of the function is mandatory for Gemini tools." + except TypeError: + return "Error: missing docstring. A description of the function is mandatory for Gemini tools." + for arg in func.args.args: if arg.annotation is None: return f"Error: missing type annotation for '{arg.arg}'" diff --git a/GoT/tools/math_tool.py b/GoT/agent_tools/math_tool.py similarity index 100% rename from GoT/tools/math_tool.py rename to GoT/agent_tools/math_tool.py diff --git a/GoT/tools/runtime_graph_tool.py b/GoT/agent_tools/runtime_graph_tool.py similarity index 91% rename from GoT/tools/runtime_graph_tool.py rename to GoT/agent_tools/runtime_graph_tool.py index 77df52a..7db745c 100644 --- a/GoT/tools/runtime_graph_tool.py +++ b/GoT/agent_tools/runtime_graph_tool.py @@ -1,9 +1,9 @@ from langchain.messages import HumanMessage, SystemMessage from langchain.tools import tool -from GoT.model.ollama_llm import LLM -from GoT.model.runtime_graph import ReasoningNode, RuntimeGraph -from GoT.model.utils.utils import parse_response +from GoT.core.llm import LLM +from GoT.core.runtime_graph import ReasoningNode, RuntimeGraph +from GoT.utils.utils import parse_response MAX_INTERACTIONS = 10 @@ -21,6 +21,8 @@ def divide_thought( HOW TO USE THIS TOOL: - Call it when you think the problem is complex. - The two parts must be as independent as possible. + IMPORTANT NOTES: + - You can't use the result of the first part to reason about the second part, and vice versa. The two parts must be as independent as possible. Arguments: - first_part: the first part of the thought process - second_part: the second part of the thought process diff --git a/GoT/agent_tools/web_tool.py b/GoT/agent_tools/web_tool.py new file mode 100644 index 0000000..b56413d --- /dev/null +++ b/GoT/agent_tools/web_tool.py @@ -0,0 +1,54 @@ +import arxiv +from langchain.tools import tool +import wikipedia + + +@tool +def search_wikipedia(query: str) -> str: + """ + Fetch a brief summary from Wikipedia. + + Args: + query (str): The keyword or topic to search for. + + Returns: + str: A 3-sentence summary of the topic, the first option if + ambiguous, or an error message if not found. + """ + try: + return wikipedia.search(query) + except wikipedia.DisambiguationError as e: + # happens when query is ambiguous, pick first option + return wikipedia.summary(e.options[0], sentences=3) + except wikipedia.PageError: + return "Page not found" + + +@tool +def search_arxiv(query: str) -> str: + """Search ArXiv for scientific papers on a given topic. + Use this when you need to find research papers, abstracts or academic references.""" + + try: + client = arxiv.Client() + search = arxiv.Search( + query=query, max_results=3, sort_by=arxiv.SortCriterion.Relevance + ) + + results = [] + for paper in client.results(search): + results.append( + f"Title: {paper.title}\n" + f"Authors: {', '.join(a.name for a in paper.authors)}\n" + f"Published: {paper.published.strftime('%Y-%m-%d')}\n" + f"Summary: {paper.summary[:300]}...\n" + f"URL: {paper.entry_id}\n" + ) + + if not results: + return "No papers found for this query." + + return "\n---\n".join(results) + + except Exception as e: + return f"Error searching ArXiv: {str(e)}" diff --git a/GoT/model/utils/parse_args.py b/GoT/cli/parse_args.py similarity index 68% rename from GoT/model/utils/parse_args.py rename to GoT/cli/parse_args.py index d2437a3..8169319 100644 --- a/GoT/model/utils/parse_args.py +++ b/GoT/cli/parse_args.py @@ -1,7 +1,13 @@ import argparse import sys -from GoT.model.utils.hf_formatter import use_gpqa, use_gsm8k, use_hendrycks_math +from GoT.experiments.hf_formatter import ( + use_gaia, + use_gpqa, + use_gsm8k, + use_hendrycks_math, +) +from GoT.experiments.runner_custom import custom_test def defining_and_parse_args(): @@ -12,7 +18,7 @@ def defining_and_parse_args(): "--benchmark", required=True, type=str, - choices=["gsm8k", "gpqa", "hendrycks_math"], + choices=["gsm8k", "gpqa", "hendrycks_math", "gaia", "custom"], help="The benchmark to run the model on.", ) parser.add_argument( @@ -22,6 +28,9 @@ def defining_and_parse_args(): choices=["graph", "standard"], help="Whether to run the standard model or the graph model.", ) + parser.add_argument( + "--prompt", type=str, default="", help="Insert a prompt during a custom run." + ) parser.add_argument( "--max_run", type=int, @@ -29,7 +38,7 @@ def defining_and_parse_args(): help="The maximum number of runs for the benchmark.", ) parser.add_argument( - "--type", + "--category", type=str, default="algebra", choices=[ @@ -39,7 +48,7 @@ def defining_and_parse_args(): "intermediate_algebra", "number_theory", "precalculus", - "statistics", + "prealgebra", ], help="The type of math problems to run, only for hendrycks_math.", ) @@ -61,4 +70,10 @@ def call_benchmark(args): elif args.benchmark == "gpqa": use_gpqa(max_run=max_run, test=test, model_name=mode) elif args.benchmark == "hendrycks_math": - use_hendrycks_math(max_run=max_run, test=test, model_name=mode, type=args.type) + use_hendrycks_math( + max_run=max_run, test=test, model_name=mode, type=args.category + ) + elif args.benchmark == "gaia": + use_gaia(max_run=max_run, test=test, model_name=mode) + elif args.benchmark == "custom" and args.prompt != "": + custom_test(args.prompt, test) diff --git a/GoT/model/graph_model.py b/GoT/core/graph_model.py similarity index 80% rename from GoT/model/graph_model.py rename to GoT/core/graph_model.py index 57491ad..6ace4dc 100644 --- a/GoT/model/graph_model.py +++ b/GoT/core/graph_model.py @@ -2,8 +2,8 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from langgraph.graph import StateGraph, MessagesState, START, END -from GoT.model.ollama_llm import LLM -from GoT.model.runtime_graph import ( +from GoT.core.llm import LLM +from GoT.core.runtime_graph import ( BacktrackNode, CompletitionNode, CraftingNode, @@ -15,13 +15,14 @@ TestNode, ToolNode, ) -from GoT.model.utils.utils import ( +from GoT.utils.utils import ( extract_tool_used, + extract_tools_crafted, parse_response, parse_response_for_tool_node, parse_score, ) -from GoT.tools.runtime_graph_tool import divide_thought +from GoT.agent_tools.runtime_graph_tool import divide_thought SCORE_THRESHOLD = 5 COMPLEXITY_COEFFICIENT = 0.5 @@ -54,15 +55,16 @@ Your duty is to score, from 0 to 5, the response that user gives and assign a score. Rules: + - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. - You MUST respond ONLY using the Score function. + - You must consider if the format of the answer follow the instruction - You cannot give the full solution, only hints. - - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. - Do not write natural language outside the function. - Always consider creating a tool if it makes the response correct or reusable. Score meanings: 0: Impossible to understand / completely wrong - 1: Nearly completely wrong + 1: Nearly completely wrong / need to craft a tool 2: Correct language but does not follow instruction 3: Tries to solve but fails instruction / wrong 4: Follows instruction but result wrong or incomplete @@ -81,6 +83,7 @@ LLM().get_craft_tool(), SystemMessage( """ + You are a specialized in coding and write new useful method. You create reusable Python tools for other agents. The tool must be GENERAL and parameterized. @@ -101,18 +104,81 @@ def multiply(a: float, b: float) -> float: ' return a * b + + Bad example (hardcoded/placeholder result): + def search_papers(query: str) -> str: + return "Results about " + query # WRONG: never return hardcoded strings + + Good example (real API call): + def search_papers(query: str) -> str: + ' + Arguments: + query: the search query string + Returns: + A string with real results fetched from the API + ' + import arxiv + client = arxiv.Client() + search = arxiv.Search(query=query, max_results=3) + results = [p.title + ": " + p.summary[:200] for p in client.results(search)] + return "\\n".join(results) + + Bad example: Too specific + def get_oldest_blu_ray_title(spreadsheet_path: str) -> str: + " + Analyzes a spreadsheet to find the oldest Blu-Ray title. + + Arguments: + spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx'). + + Returns: + The title of the oldest Blu-Ray as it appears in the spreadsheet. + " + import pandas as pd + + df = pd.read_excel(spreadsheet_path) + + # Assuming 'Format' column for media type and 'Recording Date' for date + blu_rays = df[df['Format'] == 'Blu-Ray'] + + if blu_rays.empty: + return "No Blu-Ray titles found." + + # Ensure 'Recording Date' is in datetime format for proper comparison + blu_rays['Recording Date'] = pd.to_datetime(blu_rays['Recording Date']) + + oldest_blu_ray = blu_rays.sort_values(by='Recording Date', ascending=True).iloc[0] + + return oldest_blu_ray['Title'] + + Good example + def open_excel_files(excel_path: str) + Analyzes a spreadsheet. + + Arguments: + spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx'). + + Returns: + The excel file in string + " + import pandas as pd + + df = pd.read_excel(spreadsheet_path) + return df.to_string() + Rules: - Prefer generic names and parameters, never craft specific functions. - If the function contains specific numbers or values, it is wrong. - - Craft only one function, it must contains always the docs. + - The Tool must follow the Json schema protocol, Tuple is banned. + - Never return hardcoded or placeholder strings, the function must fetch real data. + - Craft a maximum of 3 tools, it must contains always the docs. If the number of tool crafted exceed, you fail. - Never craft tool that raise exceptions. - - Respond ONLY using the tool available. - No natural language. - - No comments in the python interpreter. + - No more than 1 line comments in the python codes. """ ), response_format=Response, - type="remote_response_format", + type="remote_crafter", ) reasoning_agent = LLM().create_custom_agent( @@ -138,12 +204,11 @@ def goal(prompt: MessagesState): runtime_graph.add_node(goal_node) runtime_graph.temp_node = goal_node for i in range(0, 3): + reasoning_node = ReasoningNode("") call_node = ToolNode( "Please, resolve the problem with the tools given, you MUST follow the previous reasoning.", "", - tool_name="", ) - reasoning_node = ReasoningNode("") runtime_graph.add_node(reasoning_node) runtime_graph.add_node(call_node) runtime_graph.add_edge(goal_node, reasoning_node) @@ -156,22 +221,6 @@ def goal(prompt: MessagesState): return prompt -# def tool_expand(goal: MessagesState): -# msg = parse_response(goal) -# sys_msg = "Please make a list using '-' to denote each tool in a probabilistic order, don't use this character for other reasons. Select only the tool(s) you want to use to solve this problem." -# messages = [ -# HumanMessage(msg), -# SystemMessage(sys_msg), -# ] -# res = starting_agent.invoke({"messages": messages}, config={"recursion_limit": MAX_INTERACTIONS}) -# str_res = parse_response(res) -# goal["messages"].append(AIMessage(content=str_res)) -# # tool_list = parse_tool_list(str_res) # Toglie elementi inutili -# # add tool nodes in the runtime graph - -# return goal - - def tool_reasoning(messages: MessagesState): messages["messages"].append( HumanMessage( @@ -194,7 +243,7 @@ def tool_call(messages: MessagesState): # It calls the llm and it resolves the call node call_node = runtime_graph.temp_node tool_agent = LLM().create_custom_agent( - LLM().get_tools() + [divide_thought], + LLM().get_tools(), SystemMessage( "You are an assistant specialized in tools. Your goal is to resolve the problem with " " the tool that the user indicates to you. You HAVE to use or craft the tool that the assistant indicates to you." @@ -235,6 +284,9 @@ def tool_call(messages: MessagesState): ) tool_used = extract_tool_used(res) runtime_graph.temp_response.response = parse_response_for_tool_node(res).response + runtime_graph.temp_response.explanation = parse_response_for_tool_node( + res + ).explanation parsed_res = f"Response: {parse_response_for_tool_node(res).response}\nExplanation: {parse_response_for_tool_node(res).explanation}" runtime_graph.resolve_node(call_node, parsed_res) @@ -284,23 +336,27 @@ def response_evaluation(messages: MessagesState): def crafting(messages: MessagesState): - crafting_node = CraftingNode(response="", tool_crafted="", resolved=False) + crafting_node = CraftingNode(response="", tools_crafted=[], resolved=False) runtime_graph.add_node(crafting_node) runtime_graph.add_edge(runtime_graph.temp_node, crafting_node) runtime_graph.temp_node = crafting_node + ai_feedback = runtime_graph.temp_response.explanation crafting_messages = [ HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)), + AIMessage(content=ai_feedback), SystemMessage( - content="Craft a tool to solve this problem using craft_tool. It must be a function" + content="Use the context given to craft a tool to solve this problem using craft_tool. It must be a function" ), ] - craft_res = crafter_agent.invoke( - {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS} - ) - runtime_graph.temp_response.response = parse_response_for_tool_node( - craft_res - ).response - parsed_res = f"Response: {parse_response_for_tool_node(craft_res).response}\nExplanation: {parse_response_for_tool_node(craft_res).explanation}" + try: + craft_res = crafter_agent.invoke( + {"messages": crafting_messages}, + config={"recursion_limit": MAX_INTERACTIONS}, + ) + crafting_node.tools_crafted = extract_tools_crafted(craft_res) + parsed_res = parse_response(craft_res) + except Exception: + parsed_res = "" runtime_graph.resolve_node(crafting_node, parsed_res) runtime_graph.temp_node = runtime_graph.call_tool_node() runtime_graph.add_edge(crafting_node, runtime_graph.temp_node) @@ -322,36 +378,31 @@ def crafting(messages: MessagesState): def test_result(messages: MessagesState): - n = runtime_graph.exist_tool_available() + is_tool_path_available = runtime_graph.exist_tool_available() test_node = runtime_graph.temp_node if not isinstance(test_node, TestNode): raise TypeError("Expected TestNode for scoring") - - if test_node.score >= ( + threshold = ( COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity - ): + ) + if test_node.score >= threshold: runtime_graph.add_edge(test_node, runtime_graph.temp_response) runtime_graph.temp_response.resolved = True return END elif ( - test_node.score - < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is True + test_node.score < threshold + and is_tool_path_available is True and test_node.need_tool_crafting is True ): return "crafting" - elif ( - test_node.score - < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is True - ): + elif test_node.score < threshold and is_tool_path_available is True: if test_node.need_tool_crafting is True: test_node.response = "The problem is too complex to craft a new tool, try reason step by step or divide complexity." return "backtrack" elif ( test_node.score < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is False + and is_tool_path_available is False and runtime_graph.exist_reasoning_node_available() ): return "reasoning_mode" @@ -402,7 +453,6 @@ def backtrack(messages: MessagesState): runtime_graph.add_edge( backtrack_node, runtime_graph.temp_node ) # tool call node that we want to resolve - # messages = runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node) messages.get("messages", []).append(AIMessage(backtrack_node.feedback)) return messages @@ -433,7 +483,6 @@ def call_graph(prompt: str): def invoke_graph(): graph = StateGraph(MessagesState) graph.add_node(goal) - # graph.add_node(tool_expand) graph.add_node(tool_reasoning) graph.add_node(tool_call) graph.add_node(backtrack) @@ -442,7 +491,6 @@ def invoke_graph(): graph.add_node(response_evaluation) graph.add_node(reasoning_mode) graph.add_edge(START, "goal") - # graph.add_edge("goal", "tool_expand") graph.add_edge("goal", "tool_reasoning") graph.add_edge("tool_reasoning", "tool_call") graph.add_edge("tool_call", "response_evaluation") diff --git a/GoT/model/ollama_llm.py b/GoT/core/llm.py similarity index 79% rename from GoT/model/ollama_llm.py rename to GoT/core/llm.py index 65bfcf2..6ae2014 100644 --- a/GoT/model/ollama_llm.py +++ b/GoT/core/llm.py @@ -11,7 +11,7 @@ from langchain.agents import create_agent import mlflow -from GoT.tools.math_tool import ( +from GoT.agent_tools.math_tool import ( multiply, summing, minus, @@ -19,7 +19,7 @@ divide, ) -from GoT.tools.craft_tool import craft_tool, install_dependency +from GoT.agent_tools.craft_tool import craft_tool, install_dependency load_dotenv() @@ -51,21 +51,27 @@ def __init__(self): api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) - self.remoteLLMResponseFormat = ChatGoogleGenerativeAI( + self.remoteLLMReasoning = ChatGoogleGenerativeAI( model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) - self.remoteLLMScoreFormat = ChatGoogleGenerativeAI( + self.remoteLLMCrafter = ChatGoogleGenerativeAI( model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), - temperature=0.7, # Gemini 3.0+ defaults to 1.0 + temperature=1.0, # Gemini 3.0+ defaults to 1.0 + ) + self.remoteLLMEvaluator = ChatGoogleGenerativeAI( + model="gemini-2.5-flash", + api_key=os.environ.get("GEMINI_API_KEY"), + temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) self.remoteLLMs = { "remote_standard": self.remoteLLMStandard, - "remote_response_format": self.remoteLLMResponseFormat, - "remote_score_format": self.remoteLLMScoreFormat, + "remote_response_format": self.remoteLLMReasoning, + "remote_score_format": self.remoteLLMEvaluator, + "remote_crafter": self.remoteLLMCrafter, } self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) @@ -79,7 +85,7 @@ def get_craft_tool(self): return [craft_tool, install_dependency] def get_crafted_tools(self) -> list[BaseTool]: - module_name = "GoT.tools.ai_tool" + module_name = "GoT.agent_tools.ai_tool" if module_name in sys.modules: module = importlib.reload(sys.modules[module_name]) else: diff --git a/GoT/model/runtime_graph.py b/GoT/core/runtime_graph.py similarity index 86% rename from GoT/model/runtime_graph.py rename to GoT/core/runtime_graph.py index 506f3d9..6661e15 100644 --- a/GoT/model/runtime_graph.py +++ b/GoT/core/runtime_graph.py @@ -1,17 +1,16 @@ from typing import Dict, List from langgraph.graph import MessagesState -from langchain_core.messages import AnyMessage, HumanMessage from pydantic import BaseModel, Field class RuntimeNode: - _id_counter = 0 # Contatore globale per ID unici + _id_counter = 0 # global ID counter def __init__( self, resolved: bool = False, ): - self.id = RuntimeNode._id_counter # ID unico per ogni nodo + self.id = RuntimeNode._id_counter RuntimeNode._id_counter += 1 self.resolved = resolved @@ -52,13 +51,11 @@ def __init__( self, prompt: str, response: str, - tool_name: str, resolved: bool = False, ): super().__init__(resolved) self.prompt = prompt self.response = response - self.tool_name = tool_name class GoalNode(RuntimeNode): @@ -107,22 +104,24 @@ class ResponseNode(RuntimeNode): def __init__( self, response: str, + explanation: str = "", resolved: bool = False, ): super().__init__(resolved) self.response = response + self.explanation = explanation class CraftingNode(RuntimeNode): def __init__( self, response: str, - tool_crafted: str = "", + tools_crafted: list[str] = [], resolved: bool = False, ): super().__init__(resolved) self.response = response - self.tool_crafted = tool_crafted + self.tools_crafted = tools_crafted class Score(BaseModel): @@ -168,7 +167,6 @@ class RuntimeGraph: def __init__(self): self.goal: MessagesState = MessagesState(messages=[]) self.nodes: Dict[RuntimeNode, List[RuntimeNode]] = {} - self.tools_available: Dict[RuntimeNode, str] = {} self.temp_node: RuntimeNode = RuntimeNode() self.temp_response: ResponseNode = ResponseNode(response="", resolved=False) @@ -179,9 +177,6 @@ def add_edge(self, n1: RuntimeNode, n2: RuntimeNode): self.nodes.setdefault(n1, []).append(n2) self.nodes.setdefault(n2, []) - def add_tool_link(self, call_node: RuntimeNode, tool_name: str): - self.tools_available.setdefault(call_node, tool_name) - def resolve_node(self, node: RuntimeNode, response: str) -> None: if isinstance(node, (ToolNode, TestNode, CompletitionNode)): node.response = response @@ -206,31 +201,16 @@ def exist_tool_available(self) -> bool: call_nodes = [n for n in nodes if (isinstance(n, ToolNode) and not n.resolved)] return True if call_nodes else False - def get_resolved_tools(self): - resolved_nodes = [t for t in self.tools_available.keys() if t.resolved is True] - return [self.tools_available[n] for n in resolved_nodes] - - def is_craftin_node_resolved(self) -> bool: + def is_crafting_node_resolved(self) -> bool: nodes = list(self.nodes.keys()) crafting_nodes = [ n for n in nodes if (isinstance(n, CraftingNode) and n.resolved) ] return True if crafting_nodes else False - def append_prompt_to_messages_state( - self, node: TestNode | ToolNode | CompletitionNode | GoalNode - ) -> MessagesState: - messages: list[AnyMessage] = [] - - if node.prompt: - messages.append(HumanMessage(content=node.prompt)) - - return MessagesState(messages=messages) - def clear(self): RuntimeNode._id_counter = 0 self.nodes = {} - self.tools_available = {} self.temp_node = RuntimeNode() self.temp_response = ResponseNode(response="", resolved=False) diff --git a/GoT/model/utils/hf_formatter.py b/GoT/experiments/hf_formatter.py similarity index 66% rename from GoT/model/utils/hf_formatter.py rename to GoT/experiments/hf_formatter.py index 7b1eddf..a3973c9 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/experiments/hf_formatter.py @@ -1,19 +1,25 @@ import json +import os from random import shuffle import re from datasets import Dataset, load_dataset +from huggingface_hub import hf_hub_download from langchain.messages import HumanMessage -from GoT.model.graph_model import call_graph -from GoT.model.ollama_llm import LLM -from GoT.model.utils.utils import ( +from GoT.core.graph_model import call_graph +from GoT.core.llm import LLM +from GoT.utils.utils import ( + extract_answer_from_response, extract_output, normalize_list, normalize_number, + parse_response, symbolic_equal, ) +TOKEN = os.getenv("HF_TOKEN") + class ResultEval: def __init__( @@ -100,50 +106,6 @@ def gpqa_format(dataset: Dataset) -> list[ResultEval]: return questions -def gpqa_run(questions: list[ResultEval], max_run: int, test: bool) -> list[ResultEval]: - responses = [] - run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) - for q in questions[25:]: - if run_counter >= max_run: - break - prompt = q.question - correct_letter = q.correct_answer - try: - if test: - response = extract_output( - agent.invoke( - {"messages": [HumanMessage(content=prompt)]}, - config={"recursion_limit": 10}, - ) - ) - else: - response = extract_output(call_graph(prompt)) - norm_res = normalize_number(response) - responses.append( - ResultEval( - question=prompt, - response=norm_res, - filtered_answer="", - correct_answer=correct_letter, - answer_success=0.0, - ) - ) - except Exception as e: - print(f"Error processing question: {e}") - responses.append( - ResultEval( - question=prompt, - response="Error", - filtered_answer="", - correct_answer=correct_letter, - answer_success=0.0, - ) - ) - run_counter += 1 - return responses - - def gpqa_eval(responses: list[ResultEval]): correct = 0 @@ -183,52 +145,6 @@ def gsm8k_format(dataset: Dataset) -> list[ResultEval]: return questions -def gsm8k_run( - questions: list[ResultEval], max_run: int, test: bool -) -> list[ResultEval]: - responses = [] - run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) - for q in questions: - if run_counter >= max_run: - break - prompt = q.question - correct_answer = q.correct_answer - try: - if test: - response = extract_output( - agent.invoke( - {"messages": [HumanMessage(content=prompt)]}, - config={"recursion_limit": 20}, - ) - ) - else: - response = extract_output(call_graph(prompt)) - norm_res = normalize_number(response) - responses.append( - ResultEval( - question=prompt, - response=norm_res, - filtered_answer="", - correct_answer=correct_answer, - answer_success=0.0, - ) - ) - except Exception as e: - print(f"Error processing question: {e}") - responses.append( - ResultEval( - question=prompt, - response="Error", - filtered_answer="", - correct_answer=correct_answer, - answer_success=0.0, - ) - ) - run_counter += 1 - return responses - - def gsm8k_eval(responses: list[ResultEval]): correct = 0 @@ -272,12 +188,34 @@ def hendrycks_math_format(dataset: Dataset) -> list[ResultEval]: return questions -def hendrycks_math_run( +def hendrycks_math_eval(responses: list[ResultEval]): + correct = 0 + + for res in responses: + norm_res = extract_answer_from_response(res.response) + norm_correct = normalize_number(res.correct_answer) + res.filtered_answer = norm_res + + if ( + (norm_res in norm_correct) + or (normalize_list(norm_res) == normalize_list(norm_correct)) + or (symbolic_equal(norm_res, norm_correct)) + ): + correct += 1 + res.answer_success = 1.0 + + accuracy = correct / len(responses) * 100 + print(f"Accuracy: {accuracy:.2f}%") + print(f"Total: {len(responses)}") + print(f"Correct: {correct}") + + +def benchmark_run( questions: list[ResultEval], max_run: int, test: bool ) -> list[ResultEval]: responses = [] run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) + agent = LLM().create_custom_agent(LLM().get_tools()) for q in questions: if run_counter >= max_run: break @@ -285,7 +223,7 @@ def hendrycks_math_run( correct_answer = q.correct_answer try: if test: - response = extract_output( + response = parse_response( agent.invoke( {"messages": [HumanMessage(content=prompt)]}, config={"recursion_limit": 20}, @@ -318,20 +256,46 @@ def hendrycks_math_run( return responses -def hendrycks_math_eval(responses: list[ResultEval]): +def gaia_format(dataset: Dataset) -> list[ResultEval]: + questions = [] + for data in dataset: + sample = data + question = sample["Question"] + attachment = sample.get("file_name", None) + if attachment: + abs_path = hf_hub_download( + repo_id="gaia-benchmark/GAIA", + filename=f"2023/validation/{attachment}", + repo_type="dataset", + token=TOKEN, + ) + print(abs_path) + question += f"\nAttachment file path: {abs_path}" + correct_answer = sample["Final answer"] + prompt = ( + "Answer the following question. Think step by step before answering.\n\n" + f"{question}\n" + "Answer:" + ) + + questions.append( + ResultEval.create_empty_result( + question=prompt, correct_answer=correct_answer + ) + ) + + return questions + + +def gaia_eval(responses: list[ResultEval]): correct = 0 for res in responses: - opt_res = re.search(r"\\boxed\{(.*)\}", res.response) - norm_res = opt_res.group(1) if opt_res else "N/A" + norm_res = normalize_number(res.response) norm_correct = normalize_number(res.correct_answer) res.filtered_answer = norm_res - if ( - (norm_res in norm_correct) - or (normalize_list(norm_res) == normalize_list(norm_correct)) - or (symbolic_equal(norm_res, norm_correct)) - ): + if norm_res in norm_correct: correct += 1 res.answer_success = 1.0 @@ -342,27 +306,32 @@ def hendrycks_math_eval(responses: list[ResultEval]): def use_gpqa(max_run: int, test: bool, model_name: str): - ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond") - data = ds["train"] - questions = gpqa_format(data) - responses = gpqa_run(questions, max_run=max_run, test=test) + ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train") + questions = gpqa_format(ds) + responses = benchmark_run(questions, max_run=max_run, test=test) gpqa_eval(responses) save_eval_results(responses, model_name=model_name) def use_gsm8k(max_run: int, test: bool, model_name: str): - ds = load_dataset("gsm8k", "main") - data = ds["test"] - questions = gsm8k_format(data) - responses = gsm8k_run(questions, max_run=max_run, test=test) + ds = load_dataset("gsm8k", "main", split="test") + questions = gsm8k_format(ds) + responses = benchmark_run(questions, max_run=max_run, test=test) gsm8k_eval(responses) save_eval_results(responses, model_name=model_name) def use_hendrycks_math(max_run: int, test: bool, model_name: str, type: str): - ds = load_dataset("EleutherAI/hendrycks_math", type) - data = ds["test"] - questions = hendrycks_math_format(data) - responses = hendrycks_math_run(questions, max_run=max_run, test=test) + ds = load_dataset("EleutherAI/hendrycks_math", type, split="test") + questions = hendrycks_math_format(ds) + responses = benchmark_run(questions, max_run=max_run, test=test) hendrycks_math_eval(responses) save_eval_results(responses, model_name=model_name) + + +def use_gaia(max_run: int, test: bool, model_name: str): + ds = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation") + questions = gaia_format(ds) + responses = benchmark_run(questions, max_run=max_run, test=test) + gaia_eval(responses) + save_eval_results(responses, model_name=model_name) diff --git a/GoT/model/lm_wrapper.py b/GoT/experiments/lm_wrapper.py similarity index 98% rename from GoT/model/lm_wrapper.py rename to GoT/experiments/lm_wrapper.py index 1013c0b..8f9bc42 100644 --- a/GoT/model/lm_wrapper.py +++ b/GoT/experiments/lm_wrapper.py @@ -1,10 +1,10 @@ -from GoT.model.graph_model import call_graph +from GoT.core.graph_model import call_graph from lm_eval.api.registry import register_model from lm_eval.api.model import LM -from GoT.model.ollama_llm import LLM +from GoT.core.llm import LLM from langchain_core.messages import HumanMessage -from GoT.model.utils.utils import extract_output, normalize_number, parse_response +from GoT.utils.utils import extract_output, normalize_number, parse_response class LangGraphLM: diff --git a/GoT/experiments/runner_custom.py b/GoT/experiments/runner_custom.py new file mode 100644 index 0000000..ab9ec3e --- /dev/null +++ b/GoT/experiments/runner_custom.py @@ -0,0 +1,15 @@ +from langchain.messages import HumanMessage + +from GoT.core.graph_model import call_graph +from GoT.core.llm import LLM + + +def custom_test(text: str, is_graph_mode: bool): + if not is_graph_mode: + call_graph(text) + else: + agent = LLM().create_custom_agent(LLM().get_tools()) + agent.invoke( + {"messages": [HumanMessage(content=text)]}, + config={"recursion_limit": 20}, + ) diff --git a/GoT/model/utils/utils.py b/GoT/utils/utils.py similarity index 73% rename from GoT/model/utils/utils.py rename to GoT/utils/utils.py index ac14a65..37f8e64 100644 --- a/GoT/model/utils/utils.py +++ b/GoT/utils/utils.py @@ -1,10 +1,11 @@ import json import re +import mlflow import numpy as np from sympy import simplify, sympify -from GoT.model.runtime_graph import Response, Score +from GoT.core.runtime_graph import Response, Score from langgraph.graph import MessagesState from langchain_core.messages import AIMessage @@ -16,7 +17,10 @@ def parse_response(res) -> str: :param res: the MessagesState :return: The response in string """ - return res["messages"][-1].content + response = res["messages"][-1].text + if response is None or response == "": + response = res["messages"][-1].content + return response def parse_tool_list(response: str) -> list[str]: @@ -82,8 +86,14 @@ def parse_response_for_tool_node(response: MessagesState) -> Response: if isinstance(structured_response, Response): return structured_response elif score_res is not None: - data = json.loads(score_res) - return Response.model_validate(data) + try: + data = json.loads(score_res) + return Response.model_validate(data) + except json.JSONDecodeError: + return Response( + response=score_res, + explanation="", + ) else: return Response( response="Failed to parse response", @@ -108,6 +118,47 @@ def extract_tool_used(response: MessagesState) -> list[str]: return tools_used +def extract_function_signature(tool_crafted: dict) -> str: + """ + Extract name and arguments from function string. + """ + func_str = tool_crafted.get("tool_function", "") + match = re.search(r"def (\w+)\(([^)]*)\)", func_str) + if not match: + return "" + + func_name = match.group(1) + args = match.group(2) + + clean_args = ", ".join( + arg.split("=")[0].strip() + for arg in args.split(",") + if arg.strip() and arg.strip() != "self" + ) + + return f"{func_name}({clean_args})" + + +def extract_tools_crafted(response: MessagesState) -> list[str]: + """ + Extract the tools that LLM has crafted. + + :param response: The LLM response + :type response: MessagesState + :return: The list of tools crafted + :rtype: list[str] + """ + tools_crafted = [] + for msg in response.get("messages", []): + if isinstance(msg, AIMessage): + for tool_call in msg.tool_calls: + tool_crafted = tool_call["args"] + signature = extract_function_signature(tool_crafted) + if signature != "": + tools_crafted.append(signature) + return tools_crafted + + def remove_tools_from_list(tool_list, tools_to_remove): """ Remove a list of tools and return the updated list @@ -180,6 +231,30 @@ def symbolic_equal(a, b): return False +def extract_answer_from_response(response: str) -> str: + boxed_match = re.search(r"\\boxed\{(.*)\}", response) + if boxed_match: + return boxed_match.group(1).strip() + + boxed_match_alt = re.search(r"boxed\{(.*)\}", response) + if boxed_match_alt: + return boxed_match_alt.group(1).strip() + + answer_match = re.search(r"Answer:\s*(.*)", response, re.IGNORECASE) + if answer_match: + return answer_match.group(1).strip() + + clean_response = response.strip() + if not clean_response: + return "N/A" + + try: + float(clean_response) + return clean_response + except ValueError: + return "N/A" + + def print_benchmark_result(results: dict, task_name: str, filter: str) -> None: samples = results["samples"][task_name] @@ -227,3 +302,8 @@ def print_benchmark_result_loglikehood( print(f"Total: {n_total}") print(f"Correct: {n_correct}") print(f"Wrong: {n_wrong}") + + +def download_mlflow_traces(n_max: int): + traces = mlflow.search_traces(max_results=n_max, order_by=["timestamp DESC"]) + traces.to_csv("traces.csv", index=False) diff --git a/poetry.lock b/poetry.lock index e6e3dce..f5f1ce4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -251,6 +251,22 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] trio = ["trio (>=0.31.0)", "trio (>=0.32.0)"] +[[package]] +name = "arxiv" +version = "3.0.0" +description = "Python wrapper for the arXiv API" +optional = false +python-versions = ">=3.10" +files = [ + {file = "arxiv-3.0.0-py3-none-any.whl", hash = "sha256:8b4d4e2e336bfeb71ea653623d7dadb260f682f0475cee2aecad0560a23b34db"}, + {file = "arxiv-3.0.0.tar.gz", hash = "sha256:c8cb0d31208afbc1ceb17bd3f9816c8d4c5ca1e0abf199d211e216715440498d"}, +] + +[package.dependencies] +feedparser = ">=6.0.10,<6.1.0" +requests = ">=2.32,<2.34" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + [[package]] name = "async-timeout" version = "5.0.1" @@ -273,6 +289,28 @@ files = [ {file = "attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11"}, ] +[[package]] +name = "beautifulsoup4" +version = "4.14.3" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb"}, + {file = "beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86"}, +] + +[package.dependencies] +soupsieve = ">=1.6.1" +typing-extensions = ">=4.0.0" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "blinker" version = "1.9.0" @@ -1290,6 +1328,20 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "feedparser" +version = "6.0.12" +description = "Universal feed parser, handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds" +optional = false +python-versions = ">=3.6" +files = [ + {file = "feedparser-6.0.12-py3-none-any.whl", hash = "sha256:6bbff10f5a52662c00a2e3f86a38928c37c48f77b3c511aedcd51de933549324"}, + {file = "feedparser-6.0.12.tar.gz", hash = "sha256:64f76ce90ae3e8ef5d1ede0f8d3b50ce26bcce71dd8ae5e82b1cd2d4a5f94228"}, +] + +[package.dependencies] +sgmllib3k = "*" + [[package]] name = "filelock" version = "3.24.2" @@ -5611,6 +5663,16 @@ enabler = ["pytest-enabler (>=2.2)"] test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.18.*)", "pytest-mypy"] +[[package]] +name = "sgmllib3k" +version = "1.0.0" +description = "Py3k port of sgmllib." +optional = false +python-versions = "*" +files = [ + {file = "sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9"}, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -5676,6 +5738,17 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "soupsieve" +version = "2.8.3" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.9" +files = [ + {file = "soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95"}, + {file = "soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349"}, +] + [[package]] name = "sqlalchemy" version = "2.0.46" @@ -6367,6 +6440,20 @@ markupsafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "wikipedia" +version = "1.4.0" +description = "Wikipedia API for Python" +optional = false +python-versions = "*" +files = [ + {file = "wikipedia-1.4.0.tar.gz", hash = "sha256:db0fad1829fdd441b1852306e9856398204dc0786d2996dd2e0c8bb8e26133b2"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +requests = ">=2.0.0,<3.0.0" + [[package]] name = "word2number" version = "1.1" @@ -6868,4 +6955,4 @@ cffi = ["cffi (>=1.17,<2.0)", "cffi (>=2.0.0b)"] [metadata] lock-version = "2.0" python-versions = ">= 3.10.0 < 3.14.0" -content-hash = "c7c8e8591227891fa161988b58a7a8e708d8769ce133faf007103f95df4a0ef4" +content-hash = "afb0288bd77331357bfc74c7862dee098b73e2fce3028027cc78717cf022970b" diff --git a/pyproject.toml b/pyproject.toml index 95c44ad..27f02e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ dotenv = "^0.9.9" mlflow = "^3.9.0" lm-eval = {extras = ["math"], version = "^0.4.11"} boto3 = "^1.42.51" +wikipedia = "^1.4.0" +arxiv = "^3.0.0" [tool.poetry.group.dev.dependencies] coverage = "^7.4.0" @@ -49,5 +51,5 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [[tool.mypy.overrides]] -module = ["lm_eval.*", "datasets.*", "sympy.*"] +module = ["lm_eval.*", "datasets.*", "sympy.*", "arxiv.*", "wikipedia.*"] follow_untyped_imports = true