diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 555263c..77a591e 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -60,7 +60,7 @@ jobs: npm install npx semantic-release env: - # PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} + PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }} GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }} RELEASE_TEST_PYPI: ${{ github.event.repository.is_template || contains(github.repository, 'template') }} # dry run if not on main/master branch, or if initial commit diff --git a/GoT/__init__.py b/GoT/__init__.py index a5df598..8a62a54 100644 --- a/GoT/__init__.py +++ b/GoT/__init__.py @@ -3,7 +3,6 @@ from dotenv import load_dotenv from lm_eval import evaluator, tasks -from GoT.core.graph_model import call_graph from GoT.experiments.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper from GoT.cli.parse_args import call_benchmark, defining_and_parse_args from GoT.utils.utils import ( @@ -62,14 +61,10 @@ def lm_eval_graph_benchmark(): print_benchmark_result_loglikehood(results, task_name, filter_val="none") -def custom_test(): - call_graph("Solve this integral ∫x2⋅ex2dx") - - def main(): - # It could be changed with custom_test() to test a custom problem instead of the benchmark args = defining_and_parse_args() call_benchmark(args) + # download_mlflow_traces(50) # let this be the last line of this file diff --git a/GoT/agent_tools/craft_tool.py b/GoT/agent_tools/craft_tool.py index dc1d2d1..d2569b4 100644 --- a/GoT/agent_tools/craft_tool.py +++ b/GoT/agent_tools/craft_tool.py @@ -83,6 +83,13 @@ def sanitize_input(query: str) -> str: func = functions[0] + try: + docstring = ast.get_docstring(func) + if not docstring or not docstring.strip(): + return "Error: missing docstring. A description of the function is mandatory for Gemini tools." + except TypeError: + return "Error: missing docstring. A description of the function is mandatory for Gemini tools." + for arg in func.args.args: if arg.annotation is None: return f"Error: missing type annotation for '{arg.arg}'" diff --git a/GoT/cli/parse_args.py b/GoT/cli/parse_args.py index f52e05e..8169319 100644 --- a/GoT/cli/parse_args.py +++ b/GoT/cli/parse_args.py @@ -7,6 +7,7 @@ use_gsm8k, use_hendrycks_math, ) +from GoT.experiments.runner_custom import custom_test def defining_and_parse_args(): @@ -17,7 +18,7 @@ def defining_and_parse_args(): "--benchmark", required=True, type=str, - choices=["gsm8k", "gpqa", "hendrycks_math", "gaia"], + choices=["gsm8k", "gpqa", "hendrycks_math", "gaia", "custom"], help="The benchmark to run the model on.", ) parser.add_argument( @@ -27,6 +28,9 @@ def defining_and_parse_args(): choices=["graph", "standard"], help="Whether to run the standard model or the graph model.", ) + parser.add_argument( + "--prompt", type=str, default="", help="Insert a prompt during a custom run." + ) parser.add_argument( "--max_run", type=int, @@ -34,7 +38,7 @@ def defining_and_parse_args(): help="The maximum number of runs for the benchmark.", ) parser.add_argument( - "--type", + "--category", type=str, default="algebra", choices=[ @@ -66,6 +70,10 @@ def call_benchmark(args): elif args.benchmark == "gpqa": use_gpqa(max_run=max_run, test=test, model_name=mode) elif args.benchmark == "hendrycks_math": - use_hendrycks_math(max_run=max_run, test=test, model_name=mode, type=args.type) + use_hendrycks_math( + max_run=max_run, test=test, model_name=mode, type=args.category + ) elif args.benchmark == "gaia": use_gaia(max_run=max_run, test=test, model_name=mode) + elif args.benchmark == "custom" and args.prompt != "": + custom_test(args.prompt, test) diff --git a/GoT/core/graph_model.py b/GoT/core/graph_model.py index c64e3b1..6ace4dc 100644 --- a/GoT/core/graph_model.py +++ b/GoT/core/graph_model.py @@ -17,6 +17,7 @@ ) from GoT.utils.utils import ( extract_tool_used, + extract_tools_crafted, parse_response, parse_response_for_tool_node, parse_score, @@ -54,16 +55,16 @@ Your duty is to score, from 0 to 5, the response that user gives and assign a score. Rules: + - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. - You MUST respond ONLY using the Score function. - You must consider if the format of the answer follow the instruction - You cannot give the full solution, only hints. - - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. - Do not write natural language outside the function. - Always consider creating a tool if it makes the response correct or reusable. Score meanings: 0: Impossible to understand / completely wrong - 1: Nearly completely wrong + 1: Nearly completely wrong / need to craft a tool 2: Correct language but does not follow instruction 3: Tries to solve but fails instruction / wrong 4: Follows instruction but result wrong or incomplete @@ -168,6 +169,7 @@ def open_excel_files(excel_path: str) Rules: - Prefer generic names and parameters, never craft specific functions. - If the function contains specific numbers or values, it is wrong. + - The Tool must follow the Json schema protocol, Tuple is banned. - Never return hardcoded or placeholder strings, the function must fetch real data. - Craft a maximum of 3 tools, it must contains always the docs. If the number of tool crafted exceed, you fail. - Never craft tool that raise exceptions. @@ -202,12 +204,11 @@ def goal(prompt: MessagesState): runtime_graph.add_node(goal_node) runtime_graph.temp_node = goal_node for i in range(0, 3): + reasoning_node = ReasoningNode("") call_node = ToolNode( "Please, resolve the problem with the tools given, you MUST follow the previous reasoning.", "", - tool_name="", ) - reasoning_node = ReasoningNode("") runtime_graph.add_node(reasoning_node) runtime_graph.add_node(call_node) runtime_graph.add_edge(goal_node, reasoning_node) @@ -242,7 +243,7 @@ def tool_call(messages: MessagesState): # It calls the llm and it resolves the call node call_node = runtime_graph.temp_node tool_agent = LLM().create_custom_agent( - LLM().get_tools() + [divide_thought], + LLM().get_tools(), SystemMessage( "You are an assistant specialized in tools. Your goal is to resolve the problem with " " the tool that the user indicates to you. You HAVE to use or craft the tool that the assistant indicates to you." @@ -335,7 +336,7 @@ def response_evaluation(messages: MessagesState): def crafting(messages: MessagesState): - crafting_node = CraftingNode(response="", tool_crafted="", resolved=False) + crafting_node = CraftingNode(response="", tools_crafted=[], resolved=False) runtime_graph.add_node(crafting_node) runtime_graph.add_edge(runtime_graph.temp_node, crafting_node) runtime_graph.temp_node = crafting_node @@ -352,12 +353,10 @@ def crafting(messages: MessagesState): {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS}, ) + crafting_node.tools_crafted = extract_tools_crafted(craft_res) parsed_res = parse_response(craft_res) except Exception: parsed_res = "" - # runtime_graph.temp_response.response = parse_response_for_tool_node( - # craft_res - # ).response runtime_graph.resolve_node(crafting_node, parsed_res) runtime_graph.temp_node = runtime_graph.call_tool_node() runtime_graph.add_edge(crafting_node, runtime_graph.temp_node) @@ -454,7 +453,6 @@ def backtrack(messages: MessagesState): runtime_graph.add_edge( backtrack_node, runtime_graph.temp_node ) # tool call node that we want to resolve - # messages = runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node) messages.get("messages", []).append(AIMessage(backtrack_node.feedback)) return messages diff --git a/GoT/core/llm.py b/GoT/core/llm.py index f76b50f..6ae2014 100644 --- a/GoT/core/llm.py +++ b/GoT/core/llm.py @@ -51,26 +51,26 @@ def __init__(self): api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) - self.remoteLLMResponseFormat = ChatGoogleGenerativeAI( + self.remoteLLMReasoning = ChatGoogleGenerativeAI( model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) self.remoteLLMCrafter = ChatGoogleGenerativeAI( - model="gemini-3-flash-preview", + model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) - self.remoteLLMScoreFormat = ChatGoogleGenerativeAI( + self.remoteLLMEvaluator = ChatGoogleGenerativeAI( model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), - temperature=0.7, # Gemini 3.0+ defaults to 1.0 + temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) self.remoteLLMs = { "remote_standard": self.remoteLLMStandard, - "remote_response_format": self.remoteLLMResponseFormat, - "remote_score_format": self.remoteLLMScoreFormat, + "remote_response_format": self.remoteLLMReasoning, + "remote_score_format": self.remoteLLMEvaluator, "remote_crafter": self.remoteLLMCrafter, } diff --git a/GoT/core/runtime_graph.py b/GoT/core/runtime_graph.py index 6047e23..6661e15 100644 --- a/GoT/core/runtime_graph.py +++ b/GoT/core/runtime_graph.py @@ -1,17 +1,16 @@ from typing import Dict, List from langgraph.graph import MessagesState -from langchain_core.messages import AnyMessage, HumanMessage from pydantic import BaseModel, Field class RuntimeNode: - _id_counter = 0 # Contatore globale per ID unici + _id_counter = 0 # global ID counter def __init__( self, resolved: bool = False, ): - self.id = RuntimeNode._id_counter # ID unico per ogni nodo + self.id = RuntimeNode._id_counter RuntimeNode._id_counter += 1 self.resolved = resolved @@ -52,13 +51,11 @@ def __init__( self, prompt: str, response: str, - tool_name: str, resolved: bool = False, ): super().__init__(resolved) self.prompt = prompt self.response = response - self.tool_name = tool_name class GoalNode(RuntimeNode): @@ -119,12 +116,12 @@ class CraftingNode(RuntimeNode): def __init__( self, response: str, - tool_crafted: str = "", + tools_crafted: list[str] = [], resolved: bool = False, ): super().__init__(resolved) self.response = response - self.tool_crafted = tool_crafted + self.tools_crafted = tools_crafted class Score(BaseModel): @@ -170,7 +167,6 @@ class RuntimeGraph: def __init__(self): self.goal: MessagesState = MessagesState(messages=[]) self.nodes: Dict[RuntimeNode, List[RuntimeNode]] = {} - self.tools_available: Dict[RuntimeNode, str] = {} self.temp_node: RuntimeNode = RuntimeNode() self.temp_response: ResponseNode = ResponseNode(response="", resolved=False) @@ -181,9 +177,6 @@ def add_edge(self, n1: RuntimeNode, n2: RuntimeNode): self.nodes.setdefault(n1, []).append(n2) self.nodes.setdefault(n2, []) - def add_tool_link(self, call_node: RuntimeNode, tool_name: str): - self.tools_available.setdefault(call_node, tool_name) - def resolve_node(self, node: RuntimeNode, response: str) -> None: if isinstance(node, (ToolNode, TestNode, CompletitionNode)): node.response = response @@ -208,31 +201,16 @@ def exist_tool_available(self) -> bool: call_nodes = [n for n in nodes if (isinstance(n, ToolNode) and not n.resolved)] return True if call_nodes else False - def get_resolved_tools(self): - resolved_nodes = [t for t in self.tools_available.keys() if t.resolved is True] - return [self.tools_available[n] for n in resolved_nodes] - - def is_craftin_node_resolved(self) -> bool: + def is_crafting_node_resolved(self) -> bool: nodes = list(self.nodes.keys()) crafting_nodes = [ n for n in nodes if (isinstance(n, CraftingNode) and n.resolved) ] return True if crafting_nodes else False - def append_prompt_to_messages_state( - self, node: TestNode | ToolNode | CompletitionNode | GoalNode - ) -> MessagesState: - messages: list[AnyMessage] = [] - - if node.prompt: - messages.append(HumanMessage(content=node.prompt)) - - return MessagesState(messages=messages) - def clear(self): RuntimeNode._id_counter = 0 self.nodes = {} - self.tools_available = {} self.temp_node = RuntimeNode() self.temp_response = ResponseNode(response="", resolved=False) diff --git a/GoT/experiments/hf_formatter.py b/GoT/experiments/hf_formatter.py index 6e2ebc8..a3973c9 100644 --- a/GoT/experiments/hf_formatter.py +++ b/GoT/experiments/hf_formatter.py @@ -10,9 +10,11 @@ from GoT.core.graph_model import call_graph from GoT.core.llm import LLM from GoT.utils.utils import ( + extract_answer_from_response, extract_output, normalize_list, normalize_number, + parse_response, symbolic_equal, ) @@ -190,8 +192,7 @@ def hendrycks_math_eval(responses: list[ResultEval]): correct = 0 for res in responses: - opt_res = re.search(r"\\boxed\{(.*)\}", res.response) - norm_res = opt_res.group(1) if opt_res else "N/A" + norm_res = extract_answer_from_response(res.response) norm_correct = normalize_number(res.correct_answer) res.filtered_answer = norm_res @@ -222,7 +223,7 @@ def benchmark_run( correct_answer = q.correct_answer try: if test: - response = extract_output( + response = parse_response( agent.invoke( {"messages": [HumanMessage(content=prompt)]}, config={"recursion_limit": 20}, diff --git a/GoT/experiments/runner_custom.py b/GoT/experiments/runner_custom.py new file mode 100644 index 0000000..ab9ec3e --- /dev/null +++ b/GoT/experiments/runner_custom.py @@ -0,0 +1,15 @@ +from langchain.messages import HumanMessage + +from GoT.core.graph_model import call_graph +from GoT.core.llm import LLM + + +def custom_test(text: str, is_graph_mode: bool): + if not is_graph_mode: + call_graph(text) + else: + agent = LLM().create_custom_agent(LLM().get_tools()) + agent.invoke( + {"messages": [HumanMessage(content=text)]}, + config={"recursion_limit": 20}, + ) diff --git a/GoT/utils/utils.py b/GoT/utils/utils.py index 1b936b4..37f8e64 100644 --- a/GoT/utils/utils.py +++ b/GoT/utils/utils.py @@ -17,7 +17,10 @@ def parse_response(res) -> str: :param res: the MessagesState :return: The response in string """ - return res["messages"][-1].content + response = res["messages"][-1].text + if response is None or response == "": + response = res["messages"][-1].content + return response def parse_tool_list(response: str) -> list[str]: @@ -83,8 +86,14 @@ def parse_response_for_tool_node(response: MessagesState) -> Response: if isinstance(structured_response, Response): return structured_response elif score_res is not None: - data = json.loads(score_res) - return Response.model_validate(data) + try: + data = json.loads(score_res) + return Response.model_validate(data) + except json.JSONDecodeError: + return Response( + response=score_res, + explanation="", + ) else: return Response( response="Failed to parse response", @@ -109,6 +118,47 @@ def extract_tool_used(response: MessagesState) -> list[str]: return tools_used +def extract_function_signature(tool_crafted: dict) -> str: + """ + Extract name and arguments from function string. + """ + func_str = tool_crafted.get("tool_function", "") + match = re.search(r"def (\w+)\(([^)]*)\)", func_str) + if not match: + return "" + + func_name = match.group(1) + args = match.group(2) + + clean_args = ", ".join( + arg.split("=")[0].strip() + for arg in args.split(",") + if arg.strip() and arg.strip() != "self" + ) + + return f"{func_name}({clean_args})" + + +def extract_tools_crafted(response: MessagesState) -> list[str]: + """ + Extract the tools that LLM has crafted. + + :param response: The LLM response + :type response: MessagesState + :return: The list of tools crafted + :rtype: list[str] + """ + tools_crafted = [] + for msg in response.get("messages", []): + if isinstance(msg, AIMessage): + for tool_call in msg.tool_calls: + tool_crafted = tool_call["args"] + signature = extract_function_signature(tool_crafted) + if signature != "": + tools_crafted.append(signature) + return tools_crafted + + def remove_tools_from_list(tool_list, tools_to_remove): """ Remove a list of tools and return the updated list @@ -181,6 +231,30 @@ def symbolic_equal(a, b): return False +def extract_answer_from_response(response: str) -> str: + boxed_match = re.search(r"\\boxed\{(.*)\}", response) + if boxed_match: + return boxed_match.group(1).strip() + + boxed_match_alt = re.search(r"boxed\{(.*)\}", response) + if boxed_match_alt: + return boxed_match_alt.group(1).strip() + + answer_match = re.search(r"Answer:\s*(.*)", response, re.IGNORECASE) + if answer_match: + return answer_match.group(1).strip() + + clean_response = response.strip() + if not clean_response: + return "N/A" + + try: + float(clean_response) + return clean_response + except ValueError: + return "N/A" + + def print_benchmark_result(results: dict, task_name: str, filter: str) -> None: samples = results["samples"][task_name]