From 677cab637884271ae7e7ac0ef6d2d3cffc58be0f Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Wed, 15 Apr 2026 10:35:13 +0200 Subject: [PATCH 01/14] chore: improve prompt --- GoT/model/graph_model.py | 1 + GoT/tools/runtime_graph_tool.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index 57491ad..db01896 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -55,6 +55,7 @@ Rules: - You MUST respond ONLY using the Score function. + - You must consider if the format of the answer follow the instruction - You cannot give the full solution, only hints. - If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem. - Do not write natural language outside the function. diff --git a/GoT/tools/runtime_graph_tool.py b/GoT/tools/runtime_graph_tool.py index 77df52a..6b2a7ef 100644 --- a/GoT/tools/runtime_graph_tool.py +++ b/GoT/tools/runtime_graph_tool.py @@ -21,6 +21,8 @@ def divide_thought( HOW TO USE THIS TOOL: - Call it when you think the problem is complex. - The two parts must be as independent as possible. + IMPORTANT NOTES: + - You can't use the result of the first part to reason about the second part, and vice versa. The two parts must be as independent as possible. Arguments: - first_part: the first part of the thought process - second_part: the second part of the thought process From 6dc36a1323693394e27949a59386ddc4d4958efb Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Thu, 16 Apr 2026 10:12:49 +0200 Subject: [PATCH 02/14] chore: remove craft tools in hf formatter --- GoT/model/utils/hf_formatter.py | 8 ++++---- GoT/tools/craft_tool.py | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/GoT/model/utils/hf_formatter.py b/GoT/model/utils/hf_formatter.py index 7b1eddf..eadd6b0 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/model/utils/hf_formatter.py @@ -103,7 +103,7 @@ def gpqa_format(dataset: Dataset) -> list[ResultEval]: def gpqa_run(questions: list[ResultEval], max_run: int, test: bool) -> list[ResultEval]: responses = [] run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) + agent = LLM().create_custom_agent(LLM().get_tools()) for q in questions[25:]: if run_counter >= max_run: break @@ -188,7 +188,7 @@ def gsm8k_run( ) -> list[ResultEval]: responses = [] run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) + agent = LLM().create_custom_agent(LLM().get_tools()) for q in questions: if run_counter >= max_run: break @@ -277,8 +277,8 @@ def hendrycks_math_run( ) -> list[ResultEval]: responses = [] run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools() + LLM().get_craft_tool()) - for q in questions: + agent = LLM().create_custom_agent(LLM().get_tools()) + for q in questions[10:]: if run_counter >= max_run: break prompt = q.question diff --git a/GoT/tools/craft_tool.py b/GoT/tools/craft_tool.py index 8db13dd..dc1d2d1 100644 --- a/GoT/tools/craft_tool.py +++ b/GoT/tools/craft_tool.py @@ -56,6 +56,8 @@ def craft_tool(tool_function: str) -> str: """Save the function definition provided by the LLM as a tool that can be used by other agents. The function should be defined as a python function. The function should be general and reusable, and should not be specific to the current problem. + The function must not use tuple as args type. + The function must be defined as gemini api format, with type annotations for all arguments and return type. The function should be defined in a way that it can be imported and used by other agents.""" def sanitize_input(query: str) -> str: From 43331413f729730ce3f71ff7f916d19c55b65088 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Thu, 16 Apr 2026 16:55:26 +0200 Subject: [PATCH 03/14] chore: create a single benchmark_run for all datasets --- GoT/model/utils/hf_formatter.py | 145 ++++++-------------------------- 1 file changed, 27 insertions(+), 118 deletions(-) diff --git a/GoT/model/utils/hf_formatter.py b/GoT/model/utils/hf_formatter.py index eadd6b0..fbbdb59 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/model/utils/hf_formatter.py @@ -100,50 +100,6 @@ def gpqa_format(dataset: Dataset) -> list[ResultEval]: return questions -def gpqa_run(questions: list[ResultEval], max_run: int, test: bool) -> list[ResultEval]: - responses = [] - run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools()) - for q in questions[25:]: - if run_counter >= max_run: - break - prompt = q.question - correct_letter = q.correct_answer - try: - if test: - response = extract_output( - agent.invoke( - {"messages": [HumanMessage(content=prompt)]}, - config={"recursion_limit": 10}, - ) - ) - else: - response = extract_output(call_graph(prompt)) - norm_res = normalize_number(response) - responses.append( - ResultEval( - question=prompt, - response=norm_res, - filtered_answer="", - correct_answer=correct_letter, - answer_success=0.0, - ) - ) - except Exception as e: - print(f"Error processing question: {e}") - responses.append( - ResultEval( - question=prompt, - response="Error", - filtered_answer="", - correct_answer=correct_letter, - answer_success=0.0, - ) - ) - run_counter += 1 - return responses - - def gpqa_eval(responses: list[ResultEval]): correct = 0 @@ -183,52 +139,6 @@ def gsm8k_format(dataset: Dataset) -> list[ResultEval]: return questions -def gsm8k_run( - questions: list[ResultEval], max_run: int, test: bool -) -> list[ResultEval]: - responses = [] - run_counter = 0 - agent = LLM().create_custom_agent(LLM().get_tools()) - for q in questions: - if run_counter >= max_run: - break - prompt = q.question - correct_answer = q.correct_answer - try: - if test: - response = extract_output( - agent.invoke( - {"messages": [HumanMessage(content=prompt)]}, - config={"recursion_limit": 20}, - ) - ) - else: - response = extract_output(call_graph(prompt)) - norm_res = normalize_number(response) - responses.append( - ResultEval( - question=prompt, - response=norm_res, - filtered_answer="", - correct_answer=correct_answer, - answer_success=0.0, - ) - ) - except Exception as e: - print(f"Error processing question: {e}") - responses.append( - ResultEval( - question=prompt, - response="Error", - filtered_answer="", - correct_answer=correct_answer, - answer_success=0.0, - ) - ) - run_counter += 1 - return responses - - def gsm8k_eval(responses: list[ResultEval]): correct = 0 @@ -272,13 +182,35 @@ def hendrycks_math_format(dataset: Dataset) -> list[ResultEval]: return questions -def hendrycks_math_run( +def hendrycks_math_eval(responses: list[ResultEval]): + correct = 0 + + for res in responses: + opt_res = re.search(r"\\boxed\{(.*)\}", res.response) + norm_res = opt_res.group(1) if opt_res else "N/A" + norm_correct = normalize_number(res.correct_answer) + res.filtered_answer = norm_res + + if ( + (norm_res in norm_correct) + or (normalize_list(norm_res) == normalize_list(norm_correct)) + or (symbolic_equal(norm_res, norm_correct)) + ): + correct += 1 + res.answer_success = 1.0 + + accuracy = correct / len(responses) * 100 + print(f"Accuracy: {accuracy:.2f}%") + print(f"Total: {len(responses)}") + print(f"Correct: {correct}") + +def benchmark_run( questions: list[ResultEval], max_run: int, test: bool ) -> list[ResultEval]: responses = [] run_counter = 0 agent = LLM().create_custom_agent(LLM().get_tools()) - for q in questions[10:]: + for q in questions[40:]: if run_counter >= max_run: break prompt = q.question @@ -318,34 +250,11 @@ def hendrycks_math_run( return responses -def hendrycks_math_eval(responses: list[ResultEval]): - correct = 0 - - for res in responses: - opt_res = re.search(r"\\boxed\{(.*)\}", res.response) - norm_res = opt_res.group(1) if opt_res else "N/A" - norm_correct = normalize_number(res.correct_answer) - res.filtered_answer = norm_res - - if ( - (norm_res in norm_correct) - or (normalize_list(norm_res) == normalize_list(norm_correct)) - or (symbolic_equal(norm_res, norm_correct)) - ): - correct += 1 - res.answer_success = 1.0 - - accuracy = correct / len(responses) * 100 - print(f"Accuracy: {accuracy:.2f}%") - print(f"Total: {len(responses)}") - print(f"Correct: {correct}") - - def use_gpqa(max_run: int, test: bool, model_name: str): ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond") data = ds["train"] questions = gpqa_format(data) - responses = gpqa_run(questions, max_run=max_run, test=test) + responses = benchmark_run(questions, max_run=max_run, test=test) gpqa_eval(responses) save_eval_results(responses, model_name=model_name) @@ -354,7 +263,7 @@ def use_gsm8k(max_run: int, test: bool, model_name: str): ds = load_dataset("gsm8k", "main") data = ds["test"] questions = gsm8k_format(data) - responses = gsm8k_run(questions, max_run=max_run, test=test) + responses = benchmark_run(questions, max_run=max_run, test=test) gsm8k_eval(responses) save_eval_results(responses, model_name=model_name) @@ -363,6 +272,6 @@ def use_hendrycks_math(max_run: int, test: bool, model_name: str, type: str): ds = load_dataset("EleutherAI/hendrycks_math", type) data = ds["test"] questions = hendrycks_math_format(data) - responses = hendrycks_math_run(questions, max_run=max_run, test=test) + responses = benchmark_run(questions, max_run=max_run, test=test) hendrycks_math_eval(responses) save_eval_results(responses, model_name=model_name) From 02a6e8e16550dd34efaef811542791ad5e876880 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Thu, 16 Apr 2026 17:15:31 +0200 Subject: [PATCH 04/14] feat: gaia benchmark added --- GoT/model/utils/hf_formatter.py | 45 +++++++++++++++++++++++++++++++++ GoT/model/utils/parse_args.py | 6 +++-- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/GoT/model/utils/hf_formatter.py b/GoT/model/utils/hf_formatter.py index fbbdb59..a26ecce 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/model/utils/hf_formatter.py @@ -250,6 +250,43 @@ def benchmark_run( return responses +def gaia_format(dataset: Dataset) -> list[ResultEval]: + questions = [] + for data in dataset: + sample = data + question = sample["Question"] + correct_answer = sample["Final answer"] + prompt = ( + "Answer the following question. Think step by step before answering.\n\n" + f"{question}\n" + "Answer:" + ) + + questions.append( + ResultEval.create_empty_result( + question=prompt, correct_answer=correct_answer + ) + ) + + return questions + +def gaia_eval(responses: list[ResultEval]): + correct = 0 + + for res in responses: + norm_res = normalize_number(res.response) + norm_correct = normalize_number(res.correct_answer) + res.filtered_answer = norm_res + + if norm_res in norm_correct: + correct += 1 + res.answer_success = 1.0 + + accuracy = correct / len(responses) * 100 + print(f"Accuracy: {accuracy:.2f}%") + print(f"Total: {len(responses)}") + print(f"Correct: {correct}") + def use_gpqa(max_run: int, test: bool, model_name: str): ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond") data = ds["train"] @@ -275,3 +312,11 @@ def use_hendrycks_math(max_run: int, test: bool, model_name: str, type: str): responses = benchmark_run(questions, max_run=max_run, test=test) hendrycks_math_eval(responses) save_eval_results(responses, model_name=model_name) + +def use_gaia(max_run: int, test: bool, model_name: str): + ds = load_dataset("gaia-benchmark/GAIA", "2023_level1") + data = ds["test"] + questions = gaia_format(data) + responses = benchmark_run(questions, max_run=max_run, test=test) + gaia_eval(responses) + save_eval_results(responses, model_name=model_name) diff --git a/GoT/model/utils/parse_args.py b/GoT/model/utils/parse_args.py index d2437a3..6cb3b9d 100644 --- a/GoT/model/utils/parse_args.py +++ b/GoT/model/utils/parse_args.py @@ -1,7 +1,7 @@ import argparse import sys -from GoT.model.utils.hf_formatter import use_gpqa, use_gsm8k, use_hendrycks_math +from GoT.model.utils.hf_formatter import use_gaia, use_gpqa, use_gsm8k, use_hendrycks_math def defining_and_parse_args(): @@ -12,7 +12,7 @@ def defining_and_parse_args(): "--benchmark", required=True, type=str, - choices=["gsm8k", "gpqa", "hendrycks_math"], + choices=["gsm8k", "gpqa", "hendrycks_math", "gaia"], help="The benchmark to run the model on.", ) parser.add_argument( @@ -62,3 +62,5 @@ def call_benchmark(args): use_gpqa(max_run=max_run, test=test, model_name=mode) elif args.benchmark == "hendrycks_math": use_hendrycks_math(max_run=max_run, test=test, model_name=mode, type=args.type) + elif args.benchmark == "gaia": + use_gaia(max_run=max_run, test=test, model_name=mode) From 626007f61628052c0680d0faf38fefff10e7e430 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Fri, 17 Apr 2026 15:34:16 +0200 Subject: [PATCH 05/14] chore: simplify codes --- GoT/model/utils/hf_formatter.py | 35 +++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/GoT/model/utils/hf_formatter.py b/GoT/model/utils/hf_formatter.py index a26ecce..5c9b162 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/model/utils/hf_formatter.py @@ -1,7 +1,9 @@ import json +import os from random import shuffle import re from datasets import Dataset, load_dataset +from huggingface_hub import hf_hub_download from langchain.messages import HumanMessage @@ -14,6 +16,7 @@ symbolic_equal, ) +TOKEN = os.getenv("HF_TOKEN") class ResultEval: def __init__( @@ -210,7 +213,7 @@ def benchmark_run( responses = [] run_counter = 0 agent = LLM().create_custom_agent(LLM().get_tools()) - for q in questions[40:]: + for q in questions: if run_counter >= max_run: break prompt = q.question @@ -255,6 +258,16 @@ def gaia_format(dataset: Dataset) -> list[ResultEval]: for data in dataset: sample = data question = sample["Question"] + attachment = sample.get("file_name", None) + if attachment: + abs_path = hf_hub_download( + repo_id="gaia-benchmark/GAIA", + filename=f"2023/validation/{attachment}", + repo_type="dataset", + token=TOKEN + ) + print(abs_path) + question += f"\nAttachment file path: {abs_path}" correct_answer = sample["Final answer"] prompt = ( "Answer the following question. Think step by step before answering.\n\n" @@ -288,35 +301,31 @@ def gaia_eval(responses: list[ResultEval]): print(f"Correct: {correct}") def use_gpqa(max_run: int, test: bool, model_name: str): - ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond") - data = ds["train"] - questions = gpqa_format(data) + ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train") + questions = gpqa_format(ds) responses = benchmark_run(questions, max_run=max_run, test=test) gpqa_eval(responses) save_eval_results(responses, model_name=model_name) def use_gsm8k(max_run: int, test: bool, model_name: str): - ds = load_dataset("gsm8k", "main") - data = ds["test"] - questions = gsm8k_format(data) + ds = load_dataset("gsm8k", "main", split="test") + questions = gsm8k_format(ds) responses = benchmark_run(questions, max_run=max_run, test=test) gsm8k_eval(responses) save_eval_results(responses, model_name=model_name) def use_hendrycks_math(max_run: int, test: bool, model_name: str, type: str): - ds = load_dataset("EleutherAI/hendrycks_math", type) - data = ds["test"] - questions = hendrycks_math_format(data) + ds = load_dataset("EleutherAI/hendrycks_math", type, split="test") + questions = hendrycks_math_format(ds) responses = benchmark_run(questions, max_run=max_run, test=test) hendrycks_math_eval(responses) save_eval_results(responses, model_name=model_name) def use_gaia(max_run: int, test: bool, model_name: str): - ds = load_dataset("gaia-benchmark/GAIA", "2023_level1") - data = ds["test"] - questions = gaia_format(data) + ds = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation") + questions = gaia_format(ds) responses = benchmark_run(questions, max_run=max_run, test=test) gaia_eval(responses) save_eval_results(responses, model_name=model_name) From f62ed949b8091ee4b820dad0b5b07db5399870a4 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Fri, 17 Apr 2026 15:34:47 +0200 Subject: [PATCH 06/14] chore: add explanation of tool needed --- GoT/model/graph_model.py | 4 ++++ GoT/model/runtime_graph.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index db01896..ab6cc1c 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -82,6 +82,7 @@ LLM().get_craft_tool(), SystemMessage( """ + You are a master coder specialized in crafting tools for other agents. You create reusable Python tools for other agents. The tool must be GENERAL and parameterized. @@ -236,6 +237,7 @@ def tool_call(messages: MessagesState): ) tool_used = extract_tool_used(res) runtime_graph.temp_response.response = parse_response_for_tool_node(res).response + runtime_graph.temp_response.explanation = parse_response_for_tool_node(res).explanation parsed_res = f"Response: {parse_response_for_tool_node(res).response}\nExplanation: {parse_response_for_tool_node(res).explanation}" runtime_graph.resolve_node(call_node, parsed_res) @@ -289,8 +291,10 @@ def crafting(messages: MessagesState): runtime_graph.add_node(crafting_node) runtime_graph.add_edge(runtime_graph.temp_node, crafting_node) runtime_graph.temp_node = crafting_node + ai_feedback = runtime_graph.temp_response.explanation crafting_messages = [ HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)), + AIMessage(content=ai_feedback), SystemMessage( content="Craft a tool to solve this problem using craft_tool. It must be a function" ), diff --git a/GoT/model/runtime_graph.py b/GoT/model/runtime_graph.py index 506f3d9..6047e23 100644 --- a/GoT/model/runtime_graph.py +++ b/GoT/model/runtime_graph.py @@ -107,10 +107,12 @@ class ResponseNode(RuntimeNode): def __init__( self, response: str, + explanation: str = "", resolved: bool = False, ): super().__init__(resolved) self.response = response + self.explanation = explanation class CraftingNode(RuntimeNode): From 66811f2425c8bbbc5eca10d0240984ac94d03c05 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Tue, 21 Apr 2026 10:22:33 +0200 Subject: [PATCH 07/14] chore: remove comments --- GoT/model/graph_model.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index ab6cc1c..d289e95 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -158,22 +158,6 @@ def goal(prompt: MessagesState): return prompt -# def tool_expand(goal: MessagesState): -# msg = parse_response(goal) -# sys_msg = "Please make a list using '-' to denote each tool in a probabilistic order, don't use this character for other reasons. Select only the tool(s) you want to use to solve this problem." -# messages = [ -# HumanMessage(msg), -# SystemMessage(sys_msg), -# ] -# res = starting_agent.invoke({"messages": messages}, config={"recursion_limit": MAX_INTERACTIONS}) -# str_res = parse_response(res) -# goal["messages"].append(AIMessage(content=str_res)) -# # tool_list = parse_tool_list(str_res) # Toglie elementi inutili -# # add tool nodes in the runtime graph - -# return goal - - def tool_reasoning(messages: MessagesState): messages["messages"].append( HumanMessage( @@ -438,7 +422,6 @@ def call_graph(prompt: str): def invoke_graph(): graph = StateGraph(MessagesState) graph.add_node(goal) - # graph.add_node(tool_expand) graph.add_node(tool_reasoning) graph.add_node(tool_call) graph.add_node(backtrack) @@ -447,7 +430,6 @@ def invoke_graph(): graph.add_node(response_evaluation) graph.add_node(reasoning_mode) graph.add_edge(START, "goal") - # graph.add_edge("goal", "tool_expand") graph.add_edge("goal", "tool_reasoning") graph.add_edge("tool_reasoning", "tool_call") graph.add_edge("tool_call", "response_evaluation") From 073be66c12929be6a912874a14ebcb7489cdc89e Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Wed, 22 Apr 2026 10:14:56 +0200 Subject: [PATCH 08/14] chore: add method to download mlflow traces --- GoT/model/utils/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/GoT/model/utils/utils.py b/GoT/model/utils/utils.py index ac14a65..3ae103a 100644 --- a/GoT/model/utils/utils.py +++ b/GoT/model/utils/utils.py @@ -1,6 +1,7 @@ import json import re +import mlflow import numpy as np from sympy import simplify, sympify @@ -227,3 +228,7 @@ def print_benchmark_result_loglikehood( print(f"Total: {n_total}") print(f"Correct: {n_correct}") print(f"Wrong: {n_wrong}") + +def download_mlflow_traces(n_max: int): + traces = mlflow.search_traces(max_results=n_max, order_by=["timestamp DESC"]) + traces.to_csv("traces.csv", index=False) From 67ecf42e025968e3742740c7c8b6443a7d6e3b4b Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Thu, 23 Apr 2026 16:27:23 +0200 Subject: [PATCH 09/14] chore: change var name and simplify codes --- GoT/model/graph_model.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index d289e95..7769781 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -82,7 +82,7 @@ LLM().get_craft_tool(), SystemMessage( """ - You are a master coder specialized in crafting tools for other agents. + You are a specialized in coding and write new useful method. You create reusable Python tools for other agents. The tool must be GENERAL and parameterized. @@ -311,28 +311,24 @@ def crafting(messages: MessagesState): def test_result(messages: MessagesState): - n = runtime_graph.exist_tool_available() + is_tool_path_available = runtime_graph.exist_tool_available() test_node = runtime_graph.temp_node if not isinstance(test_node, TestNode): raise TypeError("Expected TestNode for scoring") - - if test_node.score >= ( - COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity - ): + threshold = COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity + if test_node.score >= threshold: runtime_graph.add_edge(test_node, runtime_graph.temp_response) runtime_graph.temp_response.resolved = True return END elif ( - test_node.score - < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is True + test_node.score < threshold + and is_tool_path_available is True and test_node.need_tool_crafting is True ): return "crafting" elif ( - test_node.score - < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is True + test_node.score < threshold + and is_tool_path_available is True ): if test_node.need_tool_crafting is True: test_node.response = "The problem is too complex to craft a new tool, try reason step by step or divide complexity." @@ -340,7 +336,7 @@ def test_result(messages: MessagesState): elif ( test_node.score < (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity) - and n is False + and is_tool_path_available is False and runtime_graph.exist_reasoning_node_available() ): return "reasoning_mode" From bef5e20d1a682c6da2ee5152497cfbd497dddb47 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Fri, 24 Apr 2026 10:47:38 +0200 Subject: [PATCH 10/14] feat: add wikipedia and arxiv tools --- GoT/model/ollama_llm.py | 3 +- GoT/tools/web_tool.py | 54 +++++++++++++++++++++++++ poetry.lock | 89 ++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 2 + 4 files changed, 146 insertions(+), 2 deletions(-) create mode 100644 GoT/tools/web_tool.py diff --git a/GoT/model/ollama_llm.py b/GoT/model/ollama_llm.py index 65bfcf2..5a59369 100644 --- a/GoT/model/ollama_llm.py +++ b/GoT/model/ollama_llm.py @@ -20,6 +20,7 @@ ) from GoT.tools.craft_tool import craft_tool, install_dependency +from GoT.tools.web_tool import search_arxiv, search_wikipedia load_dotenv() @@ -71,7 +72,7 @@ def __init__(self): self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) def get_tools(self): - initial_tools = [summing, minus, square_root, multiply, divide] + initial_tools = [summing, minus, square_root, multiply, divide, search_wikipedia, search_arxiv] crafted_tools = self.get_crafted_tools() return initial_tools + crafted_tools diff --git a/GoT/tools/web_tool.py b/GoT/tools/web_tool.py new file mode 100644 index 0000000..5987bce --- /dev/null +++ b/GoT/tools/web_tool.py @@ -0,0 +1,54 @@ +import arxiv +from langchain.tools import tool +import wikipedia + +@tool +def search_wikipedia(query: str) -> str: + """ + Fetch a brief summary from Wikipedia. + + Args: + query (str): The keyword or topic to search for. + + Returns: + str: A 3-sentence summary of the topic, the first option if + ambiguous, or an error message if not found. + """ + try: + return wikipedia.search(query) + except wikipedia.DisambiguationError as e: + # happens when query is ambiguous, pick first option + return wikipedia.summary(e.options[0], sentences=3) + except wikipedia.PageError: + return "Page not found" + +@tool +def search_arxiv(query: str) -> str: + """Search ArXiv for scientific papers on a given topic. + Use this when you need to find research papers, abstracts or academic references.""" + + try: + client = arxiv.Client() + search = arxiv.Search( + query=query, + max_results=3, + sort_by=arxiv.SortCriterion.Relevance + ) + + results = [] + for paper in client.results(search): + results.append( + f"Title: {paper.title}\n" + f"Authors: {', '.join(a.name for a in paper.authors)}\n" + f"Published: {paper.published.strftime('%Y-%m-%d')}\n" + f"Summary: {paper.summary[:300]}...\n" + f"URL: {paper.entry_id}\n" + ) + + if not results: + return "No papers found for this query." + + return "\n---\n".join(results) + + except Exception as e: + return f"Error searching ArXiv: {str(e)}" \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index e6e3dce..f5f1ce4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -251,6 +251,22 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] trio = ["trio (>=0.31.0)", "trio (>=0.32.0)"] +[[package]] +name = "arxiv" +version = "3.0.0" +description = "Python wrapper for the arXiv API" +optional = false +python-versions = ">=3.10" +files = [ + {file = "arxiv-3.0.0-py3-none-any.whl", hash = "sha256:8b4d4e2e336bfeb71ea653623d7dadb260f682f0475cee2aecad0560a23b34db"}, + {file = "arxiv-3.0.0.tar.gz", hash = "sha256:c8cb0d31208afbc1ceb17bd3f9816c8d4c5ca1e0abf199d211e216715440498d"}, +] + +[package.dependencies] +feedparser = ">=6.0.10,<6.1.0" +requests = ">=2.32,<2.34" +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} + [[package]] name = "async-timeout" version = "5.0.1" @@ -273,6 +289,28 @@ files = [ {file = "attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11"}, ] +[[package]] +name = "beautifulsoup4" +version = "4.14.3" +description = "Screen-scraping library" +optional = false +python-versions = ">=3.7.0" +files = [ + {file = "beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb"}, + {file = "beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86"}, +] + +[package.dependencies] +soupsieve = ">=1.6.1" +typing-extensions = ">=4.0.0" + +[package.extras] +cchardet = ["cchardet"] +chardet = ["chardet"] +charset-normalizer = ["charset-normalizer"] +html5lib = ["html5lib"] +lxml = ["lxml"] + [[package]] name = "blinker" version = "1.9.0" @@ -1290,6 +1328,20 @@ files = [ [package.extras] devel = ["colorama", "json-spec", "jsonschema", "pylint", "pytest", "pytest-benchmark", "pytest-cache", "validictory"] +[[package]] +name = "feedparser" +version = "6.0.12" +description = "Universal feed parser, handles RSS 0.9x, RSS 1.0, RSS 2.0, CDF, Atom 0.3, and Atom 1.0 feeds" +optional = false +python-versions = ">=3.6" +files = [ + {file = "feedparser-6.0.12-py3-none-any.whl", hash = "sha256:6bbff10f5a52662c00a2e3f86a38928c37c48f77b3c511aedcd51de933549324"}, + {file = "feedparser-6.0.12.tar.gz", hash = "sha256:64f76ce90ae3e8ef5d1ede0f8d3b50ce26bcce71dd8ae5e82b1cd2d4a5f94228"}, +] + +[package.dependencies] +sgmllib3k = "*" + [[package]] name = "filelock" version = "3.24.2" @@ -5611,6 +5663,16 @@ enabler = ["pytest-enabler (>=2.2)"] test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.18.*)", "pytest-mypy"] +[[package]] +name = "sgmllib3k" +version = "1.0.0" +description = "Py3k port of sgmllib." +optional = false +python-versions = "*" +files = [ + {file = "sgmllib3k-1.0.0.tar.gz", hash = "sha256:7868fb1c8bfa764c1ac563d3cf369c381d1325d36124933a726f29fcdaa812e9"}, +] + [[package]] name = "shellingham" version = "1.5.4" @@ -5676,6 +5738,17 @@ files = [ {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, ] +[[package]] +name = "soupsieve" +version = "2.8.3" +description = "A modern CSS selector implementation for Beautiful Soup." +optional = false +python-versions = ">=3.9" +files = [ + {file = "soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95"}, + {file = "soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349"}, +] + [[package]] name = "sqlalchemy" version = "2.0.46" @@ -6367,6 +6440,20 @@ markupsafe = ">=2.1.1" [package.extras] watchdog = ["watchdog (>=2.3)"] +[[package]] +name = "wikipedia" +version = "1.4.0" +description = "Wikipedia API for Python" +optional = false +python-versions = "*" +files = [ + {file = "wikipedia-1.4.0.tar.gz", hash = "sha256:db0fad1829fdd441b1852306e9856398204dc0786d2996dd2e0c8bb8e26133b2"}, +] + +[package.dependencies] +beautifulsoup4 = "*" +requests = ">=2.0.0,<3.0.0" + [[package]] name = "word2number" version = "1.1" @@ -6868,4 +6955,4 @@ cffi = ["cffi (>=1.17,<2.0)", "cffi (>=2.0.0b)"] [metadata] lock-version = "2.0" python-versions = ">= 3.10.0 < 3.14.0" -content-hash = "c7c8e8591227891fa161988b58a7a8e708d8769ce133faf007103f95df4a0ef4" +content-hash = "afb0288bd77331357bfc74c7862dee098b73e2fce3028027cc78717cf022970b" diff --git a/pyproject.toml b/pyproject.toml index 95c44ad..20e60bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,8 @@ dotenv = "^0.9.9" mlflow = "^3.9.0" lm-eval = {extras = ["math"], version = "^0.4.11"} boto3 = "^1.42.51" +wikipedia = "^1.4.0" +arxiv = "^3.0.0" [tool.poetry.group.dev.dependencies] coverage = "^7.4.0" From f5738a2f521d343c846e9237ea1ecd6632028d81 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Fri, 24 Apr 2026 10:47:51 +0200 Subject: [PATCH 11/14] fix: fix names in type arg --- GoT/model/utils/parse_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GoT/model/utils/parse_args.py b/GoT/model/utils/parse_args.py index 6cb3b9d..818903a 100644 --- a/GoT/model/utils/parse_args.py +++ b/GoT/model/utils/parse_args.py @@ -39,7 +39,7 @@ def defining_and_parse_args(): "intermediate_algebra", "number_theory", "precalculus", - "statistics", + "prealgebra" ], help="The type of math problems to run, only for hendrycks_math.", ) From ca7848503e875783b688fc65ff63286436c67192 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Sat, 25 Apr 2026 14:24:23 +0200 Subject: [PATCH 12/14] chore: add specific crafter LLM and improve prompt --- GoT/model/graph_model.py | 89 ++++++++++++++++++++++++++++++++++------ GoT/model/ollama_llm.py | 8 +++- 2 files changed, 84 insertions(+), 13 deletions(-) diff --git a/GoT/model/graph_model.py b/GoT/model/graph_model.py index 7769781..8fc881a 100644 --- a/GoT/model/graph_model.py +++ b/GoT/model/graph_model.py @@ -103,18 +103,80 @@ def multiply(a: float, b: float) -> float: ' return a * b + + Bad example (hardcoded/placeholder result): + def search_papers(query: str) -> str: + return "Results about " + query # WRONG: never return hardcoded strings + + Good example (real API call): + def search_papers(query: str) -> str: + ' + Arguments: + query: the search query string + Returns: + A string with real results fetched from the API + ' + import arxiv + client = arxiv.Client() + search = arxiv.Search(query=query, max_results=3) + results = [p.title + ": " + p.summary[:200] for p in client.results(search)] + return "\\n".join(results) + + Bad example: Too specific + def get_oldest_blu_ray_title(spreadsheet_path: str) -> str: + " + Analyzes a spreadsheet to find the oldest Blu-Ray title. + + Arguments: + spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx'). + + Returns: + The title of the oldest Blu-Ray as it appears in the spreadsheet. + " + import pandas as pd + + df = pd.read_excel(spreadsheet_path) + + # Assuming 'Format' column for media type and 'Recording Date' for date + blu_rays = df[df['Format'] == 'Blu-Ray'] + + if blu_rays.empty: + return "No Blu-Ray titles found." + + # Ensure 'Recording Date' is in datetime format for proper comparison + blu_rays['Recording Date'] = pd.to_datetime(blu_rays['Recording Date']) + + oldest_blu_ray = blu_rays.sort_values(by='Recording Date', ascending=True).iloc[0] + + return oldest_blu_ray['Title'] + + Good example + def open_excel_files(excel_path: str) + Analyzes a spreadsheet. + + Arguments: + spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx'). + + Returns: + The excel file in string + " + import pandas as pd + + df = pd.read_excel(spreadsheet_path) + return df.to_string() + Rules: - Prefer generic names and parameters, never craft specific functions. - If the function contains specific numbers or values, it is wrong. - - Craft only one function, it must contains always the docs. + - Never return hardcoded or placeholder strings, the function must fetch real data. + - Craft a maximum of 3 tools, it must contains always the docs. If the number of tool crafted exceed, you fail. - Never craft tool that raise exceptions. - - Respond ONLY using the tool available. - No natural language. - - No comments in the python interpreter. + - No more than 1 line comments in the python codes. """ ), response_format=Response, - type="remote_response_format", + type="remote_crafter", ) reasoning_agent = LLM().create_custom_agent( @@ -280,16 +342,19 @@ def crafting(messages: MessagesState): HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)), AIMessage(content=ai_feedback), SystemMessage( - content="Craft a tool to solve this problem using craft_tool. It must be a function" + content="Use the context given to craft a tool to solve this problem using craft_tool. It must be a function" ), ] - craft_res = crafter_agent.invoke( - {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS} - ) - runtime_graph.temp_response.response = parse_response_for_tool_node( - craft_res - ).response - parsed_res = f"Response: {parse_response_for_tool_node(craft_res).response}\nExplanation: {parse_response_for_tool_node(craft_res).explanation}" + try: + craft_res = crafter_agent.invoke( + {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS} + ) + parsed_res = parse_response(craft_res) + except Exception: + parsed_res = "" + # runtime_graph.temp_response.response = parse_response_for_tool_node( + # craft_res + # ).response runtime_graph.resolve_node(crafting_node, parsed_res) runtime_graph.temp_node = runtime_graph.call_tool_node() runtime_graph.add_edge(crafting_node, runtime_graph.temp_node) diff --git a/GoT/model/ollama_llm.py b/GoT/model/ollama_llm.py index 5a59369..e06bfe9 100644 --- a/GoT/model/ollama_llm.py +++ b/GoT/model/ollama_llm.py @@ -57,6 +57,11 @@ def __init__(self): api_key=os.environ.get("GEMINI_API_KEY"), temperature=1.0, # Gemini 3.0+ defaults to 1.0 ) + self.remoteLLMCrafter = ChatGoogleGenerativeAI( + model="gemini-3-flash-preview", + api_key=os.environ.get("GEMINI_API_KEY"), + temperature=1.0, # Gemini 3.0+ defaults to 1.0 + ) self.remoteLLMScoreFormat = ChatGoogleGenerativeAI( model="gemini-2.5-flash", api_key=os.environ.get("GEMINI_API_KEY"), @@ -67,12 +72,13 @@ def __init__(self): "remote_standard": self.remoteLLMStandard, "remote_response_format": self.remoteLLMResponseFormat, "remote_score_format": self.remoteLLMScoreFormat, + "remote_crafter": self.remoteLLMCrafter } self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) def get_tools(self): - initial_tools = [summing, minus, square_root, multiply, divide, search_wikipedia, search_arxiv] + initial_tools = [summing, minus, square_root, multiply, divide] crafted_tools = self.get_crafted_tools() return initial_tools + crafted_tools From 3a87d0f838206b2680dac1295453a02d834ad8d2 Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Sat, 25 Apr 2026 15:21:25 +0200 Subject: [PATCH 13/14] style: change system folder architecture --- GoT/__init__.py | 8 ++++---- GoT/{tools => agent_tools}/ai_tool.py | 0 GoT/{tools => agent_tools}/craft_tool.py | 0 GoT/{tools => agent_tools}/math_tool.py | 0 GoT/{tools => agent_tools}/runtime_graph_tool.py | 6 +++--- GoT/{tools => agent_tools}/web_tool.py | 0 GoT/{model/utils => cli}/parse_args.py | 2 +- GoT/{model => core}/graph_model.py | 8 ++++---- GoT/{model/ollama_llm.py => core/llm.py} | 7 +++---- GoT/{model => core}/runtime_graph.py | 0 GoT/{model/utils => experiments}/hf_formatter.py | 6 +++--- GoT/{model => experiments}/lm_wrapper.py | 6 +++--- GoT/{model => }/utils/utils.py | 2 +- 13 files changed, 22 insertions(+), 23 deletions(-) rename GoT/{tools => agent_tools}/ai_tool.py (100%) rename GoT/{tools => agent_tools}/craft_tool.py (100%) rename GoT/{tools => agent_tools}/math_tool.py (100%) rename GoT/{tools => agent_tools}/runtime_graph_tool.py (95%) rename GoT/{tools => agent_tools}/web_tool.py (100%) rename GoT/{model/utils => cli}/parse_args.py (96%) rename GoT/{model => core}/graph_model.py (99%) rename GoT/{model/ollama_llm.py => core/llm.py} (94%) rename GoT/{model => core}/runtime_graph.py (100%) rename GoT/{model/utils => experiments}/hf_formatter.py (98%) rename GoT/{model => experiments}/lm_wrapper.py (98%) rename GoT/{model => }/utils/utils.py (99%) diff --git a/GoT/__init__.py b/GoT/__init__.py index c9d8fc7..a5df598 100644 --- a/GoT/__init__.py +++ b/GoT/__init__.py @@ -3,10 +3,10 @@ from dotenv import load_dotenv from lm_eval import evaluator, tasks -from GoT.model.graph_model import call_graph -from GoT.model.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper -from GoT.model.utils.parse_args import call_benchmark, defining_and_parse_args -from GoT.model.utils.utils import ( +from GoT.core.graph_model import call_graph +from GoT.experiments.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper +from GoT.cli.parse_args import call_benchmark, defining_and_parse_args +from GoT.utils.utils import ( print_benchmark_result, print_benchmark_result_loglikehood, ) diff --git a/GoT/tools/ai_tool.py b/GoT/agent_tools/ai_tool.py similarity index 100% rename from GoT/tools/ai_tool.py rename to GoT/agent_tools/ai_tool.py diff --git a/GoT/tools/craft_tool.py b/GoT/agent_tools/craft_tool.py similarity index 100% rename from GoT/tools/craft_tool.py rename to GoT/agent_tools/craft_tool.py diff --git a/GoT/tools/math_tool.py b/GoT/agent_tools/math_tool.py similarity index 100% rename from GoT/tools/math_tool.py rename to GoT/agent_tools/math_tool.py diff --git a/GoT/tools/runtime_graph_tool.py b/GoT/agent_tools/runtime_graph_tool.py similarity index 95% rename from GoT/tools/runtime_graph_tool.py rename to GoT/agent_tools/runtime_graph_tool.py index 6b2a7ef..7db745c 100644 --- a/GoT/tools/runtime_graph_tool.py +++ b/GoT/agent_tools/runtime_graph_tool.py @@ -1,9 +1,9 @@ from langchain.messages import HumanMessage, SystemMessage from langchain.tools import tool -from GoT.model.ollama_llm import LLM -from GoT.model.runtime_graph import ReasoningNode, RuntimeGraph -from GoT.model.utils.utils import parse_response +from GoT.core.llm import LLM +from GoT.core.runtime_graph import ReasoningNode, RuntimeGraph +from GoT.utils.utils import parse_response MAX_INTERACTIONS = 10 diff --git a/GoT/tools/web_tool.py b/GoT/agent_tools/web_tool.py similarity index 100% rename from GoT/tools/web_tool.py rename to GoT/agent_tools/web_tool.py diff --git a/GoT/model/utils/parse_args.py b/GoT/cli/parse_args.py similarity index 96% rename from GoT/model/utils/parse_args.py rename to GoT/cli/parse_args.py index 818903a..b7605b4 100644 --- a/GoT/model/utils/parse_args.py +++ b/GoT/cli/parse_args.py @@ -1,7 +1,7 @@ import argparse import sys -from GoT.model.utils.hf_formatter import use_gaia, use_gpqa, use_gsm8k, use_hendrycks_math +from GoT.experiments.hf_formatter import use_gaia, use_gpqa, use_gsm8k, use_hendrycks_math def defining_and_parse_args(): diff --git a/GoT/model/graph_model.py b/GoT/core/graph_model.py similarity index 99% rename from GoT/model/graph_model.py rename to GoT/core/graph_model.py index 8fc881a..b1ab35b 100644 --- a/GoT/model/graph_model.py +++ b/GoT/core/graph_model.py @@ -2,8 +2,8 @@ from langchain_core.messages import HumanMessage, SystemMessage, AIMessage from langgraph.graph import StateGraph, MessagesState, START, END -from GoT.model.ollama_llm import LLM -from GoT.model.runtime_graph import ( +from GoT.core.llm import LLM +from GoT.core.runtime_graph import ( BacktrackNode, CompletitionNode, CraftingNode, @@ -15,13 +15,13 @@ TestNode, ToolNode, ) -from GoT.model.utils.utils import ( +from GoT.utils.utils import ( extract_tool_used, parse_response, parse_response_for_tool_node, parse_score, ) -from GoT.tools.runtime_graph_tool import divide_thought +from GoT.agent_tools.runtime_graph_tool import divide_thought SCORE_THRESHOLD = 5 COMPLEXITY_COEFFICIENT = 0.5 diff --git a/GoT/model/ollama_llm.py b/GoT/core/llm.py similarity index 94% rename from GoT/model/ollama_llm.py rename to GoT/core/llm.py index e06bfe9..29d4513 100644 --- a/GoT/model/ollama_llm.py +++ b/GoT/core/llm.py @@ -11,7 +11,7 @@ from langchain.agents import create_agent import mlflow -from GoT.tools.math_tool import ( +from GoT.agent_tools.math_tool import ( multiply, summing, minus, @@ -19,8 +19,7 @@ divide, ) -from GoT.tools.craft_tool import craft_tool, install_dependency -from GoT.tools.web_tool import search_arxiv, search_wikipedia +from GoT.agent_tools.craft_tool import craft_tool, install_dependency load_dotenv() @@ -86,7 +85,7 @@ def get_craft_tool(self): return [craft_tool, install_dependency] def get_crafted_tools(self) -> list[BaseTool]: - module_name = "GoT.tools.ai_tool" + module_name = "GoT.agent_tools.ai_tool" if module_name in sys.modules: module = importlib.reload(sys.modules[module_name]) else: diff --git a/GoT/model/runtime_graph.py b/GoT/core/runtime_graph.py similarity index 100% rename from GoT/model/runtime_graph.py rename to GoT/core/runtime_graph.py diff --git a/GoT/model/utils/hf_formatter.py b/GoT/experiments/hf_formatter.py similarity index 98% rename from GoT/model/utils/hf_formatter.py rename to GoT/experiments/hf_formatter.py index 5c9b162..f7127a4 100644 --- a/GoT/model/utils/hf_formatter.py +++ b/GoT/experiments/hf_formatter.py @@ -7,9 +7,9 @@ from langchain.messages import HumanMessage -from GoT.model.graph_model import call_graph -from GoT.model.ollama_llm import LLM -from GoT.model.utils.utils import ( +from GoT.core.graph_model import call_graph +from GoT.core.llm import LLM +from GoT.utils.utils import ( extract_output, normalize_list, normalize_number, diff --git a/GoT/model/lm_wrapper.py b/GoT/experiments/lm_wrapper.py similarity index 98% rename from GoT/model/lm_wrapper.py rename to GoT/experiments/lm_wrapper.py index 1013c0b..8f9bc42 100644 --- a/GoT/model/lm_wrapper.py +++ b/GoT/experiments/lm_wrapper.py @@ -1,10 +1,10 @@ -from GoT.model.graph_model import call_graph +from GoT.core.graph_model import call_graph from lm_eval.api.registry import register_model from lm_eval.api.model import LM -from GoT.model.ollama_llm import LLM +from GoT.core.llm import LLM from langchain_core.messages import HumanMessage -from GoT.model.utils.utils import extract_output, normalize_number, parse_response +from GoT.utils.utils import extract_output, normalize_number, parse_response class LangGraphLM: diff --git a/GoT/model/utils/utils.py b/GoT/utils/utils.py similarity index 99% rename from GoT/model/utils/utils.py rename to GoT/utils/utils.py index 3ae103a..d156a87 100644 --- a/GoT/model/utils/utils.py +++ b/GoT/utils/utils.py @@ -5,7 +5,7 @@ import numpy as np from sympy import simplify, sympify -from GoT.model.runtime_graph import Response, Score +from GoT.core.runtime_graph import Response, Score from langgraph.graph import MessagesState from langchain_core.messages import AIMessage From def5e3c0e6f3d468a04b01ce1b4e0aa2d1e4892a Mon Sep 17 00:00:00 2001 From: Raggini Marco Date: Sat, 25 Apr 2026 15:26:29 +0200 Subject: [PATCH 14/14] style: ruff format + ignores arxiv, wikipedia stubs --- GoT/agent_tools/web_tool.py | 24 ++++++++++++------------ GoT/cli/parse_args.py | 9 +++++++-- GoT/core/graph_model.py | 20 +++++++++++--------- GoT/core/llm.py | 2 +- GoT/experiments/hf_formatter.py | 15 ++++++++++----- GoT/utils/utils.py | 1 + pyproject.toml | 2 +- 7 files changed, 43 insertions(+), 30 deletions(-) diff --git a/GoT/agent_tools/web_tool.py b/GoT/agent_tools/web_tool.py index 5987bce..b56413d 100644 --- a/GoT/agent_tools/web_tool.py +++ b/GoT/agent_tools/web_tool.py @@ -2,6 +2,7 @@ from langchain.tools import tool import wikipedia + @tool def search_wikipedia(query: str) -> str: """ @@ -11,7 +12,7 @@ def search_wikipedia(query: str) -> str: query (str): The keyword or topic to search for. Returns: - str: A 3-sentence summary of the topic, the first option if + str: A 3-sentence summary of the topic, the first option if ambiguous, or an error message if not found. """ try: @@ -21,20 +22,19 @@ def search_wikipedia(query: str) -> str: return wikipedia.summary(e.options[0], sentences=3) except wikipedia.PageError: return "Page not found" - + + @tool def search_arxiv(query: str) -> str: - """Search ArXiv for scientific papers on a given topic. + """Search ArXiv for scientific papers on a given topic. Use this when you need to find research papers, abstracts or academic references.""" - + try: client = arxiv.Client() search = arxiv.Search( - query=query, - max_results=3, - sort_by=arxiv.SortCriterion.Relevance + query=query, max_results=3, sort_by=arxiv.SortCriterion.Relevance ) - + results = [] for paper in client.results(search): results.append( @@ -44,11 +44,11 @@ def search_arxiv(query: str) -> str: f"Summary: {paper.summary[:300]}...\n" f"URL: {paper.entry_id}\n" ) - + if not results: return "No papers found for this query." - + return "\n---\n".join(results) - + except Exception as e: - return f"Error searching ArXiv: {str(e)}" \ No newline at end of file + return f"Error searching ArXiv: {str(e)}" diff --git a/GoT/cli/parse_args.py b/GoT/cli/parse_args.py index b7605b4..f52e05e 100644 --- a/GoT/cli/parse_args.py +++ b/GoT/cli/parse_args.py @@ -1,7 +1,12 @@ import argparse import sys -from GoT.experiments.hf_formatter import use_gaia, use_gpqa, use_gsm8k, use_hendrycks_math +from GoT.experiments.hf_formatter import ( + use_gaia, + use_gpqa, + use_gsm8k, + use_hendrycks_math, +) def defining_and_parse_args(): @@ -39,7 +44,7 @@ def defining_and_parse_args(): "intermediate_algebra", "number_theory", "precalculus", - "prealgebra" + "prealgebra", ], help="The type of math problems to run, only for hendrycks_math.", ) diff --git a/GoT/core/graph_model.py b/GoT/core/graph_model.py index b1ab35b..c64e3b1 100644 --- a/GoT/core/graph_model.py +++ b/GoT/core/graph_model.py @@ -283,7 +283,9 @@ def tool_call(messages: MessagesState): ) tool_used = extract_tool_used(res) runtime_graph.temp_response.response = parse_response_for_tool_node(res).response - runtime_graph.temp_response.explanation = parse_response_for_tool_node(res).explanation + runtime_graph.temp_response.explanation = parse_response_for_tool_node( + res + ).explanation parsed_res = f"Response: {parse_response_for_tool_node(res).response}\nExplanation: {parse_response_for_tool_node(res).explanation}" runtime_graph.resolve_node(call_node, parsed_res) @@ -345,12 +347,13 @@ def crafting(messages: MessagesState): content="Use the context given to craft a tool to solve this problem using craft_tool. It must be a function" ), ] - try: + try: craft_res = crafter_agent.invoke( - {"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS} + {"messages": crafting_messages}, + config={"recursion_limit": MAX_INTERACTIONS}, ) parsed_res = parse_response(craft_res) - except Exception: + except Exception: parsed_res = "" # runtime_graph.temp_response.response = parse_response_for_tool_node( # craft_res @@ -380,7 +383,9 @@ def test_result(messages: MessagesState): test_node = runtime_graph.temp_node if not isinstance(test_node, TestNode): raise TypeError("Expected TestNode for scoring") - threshold = COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity + threshold = ( + COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity + ) if test_node.score >= threshold: runtime_graph.add_edge(test_node, runtime_graph.temp_response) runtime_graph.temp_response.resolved = True @@ -391,10 +396,7 @@ def test_result(messages: MessagesState): and test_node.need_tool_crafting is True ): return "crafting" - elif ( - test_node.score < threshold - and is_tool_path_available is True - ): + elif test_node.score < threshold and is_tool_path_available is True: if test_node.need_tool_crafting is True: test_node.response = "The problem is too complex to craft a new tool, try reason step by step or divide complexity." return "backtrack" diff --git a/GoT/core/llm.py b/GoT/core/llm.py index 29d4513..f76b50f 100644 --- a/GoT/core/llm.py +++ b/GoT/core/llm.py @@ -71,7 +71,7 @@ def __init__(self): "remote_standard": self.remoteLLMStandard, "remote_response_format": self.remoteLLMResponseFormat, "remote_score_format": self.remoteLLMScoreFormat, - "remote_crafter": self.remoteLLMCrafter + "remote_crafter": self.remoteLLMCrafter, } self.system_prompt = SystemMessage(SYSTEM_PROMPT_GENERAL) diff --git a/GoT/experiments/hf_formatter.py b/GoT/experiments/hf_formatter.py index f7127a4..6e2ebc8 100644 --- a/GoT/experiments/hf_formatter.py +++ b/GoT/experiments/hf_formatter.py @@ -18,6 +18,7 @@ TOKEN = os.getenv("HF_TOKEN") + class ResultEval: def __init__( self, @@ -207,6 +208,7 @@ def hendrycks_math_eval(responses: list[ResultEval]): print(f"Total: {len(responses)}") print(f"Correct: {correct}") + def benchmark_run( questions: list[ResultEval], max_run: int, test: bool ) -> list[ResultEval]: @@ -261,11 +263,11 @@ def gaia_format(dataset: Dataset) -> list[ResultEval]: attachment = sample.get("file_name", None) if attachment: abs_path = hf_hub_download( - repo_id="gaia-benchmark/GAIA", - filename=f"2023/validation/{attachment}", - repo_type="dataset", - token=TOKEN - ) + repo_id="gaia-benchmark/GAIA", + filename=f"2023/validation/{attachment}", + repo_type="dataset", + token=TOKEN, + ) print(abs_path) question += f"\nAttachment file path: {abs_path}" correct_answer = sample["Final answer"] @@ -283,6 +285,7 @@ def gaia_format(dataset: Dataset) -> list[ResultEval]: return questions + def gaia_eval(responses: list[ResultEval]): correct = 0 @@ -300,6 +303,7 @@ def gaia_eval(responses: list[ResultEval]): print(f"Total: {len(responses)}") print(f"Correct: {correct}") + def use_gpqa(max_run: int, test: bool, model_name: str): ds = load_dataset("Idavidrein/gpqa", "gpqa_diamond", split="train") questions = gpqa_format(ds) @@ -323,6 +327,7 @@ def use_hendrycks_math(max_run: int, test: bool, model_name: str, type: str): hendrycks_math_eval(responses) save_eval_results(responses, model_name=model_name) + def use_gaia(max_run: int, test: bool, model_name: str): ds = load_dataset("gaia-benchmark/GAIA", "2023_level1", split="validation") questions = gaia_format(ds) diff --git a/GoT/utils/utils.py b/GoT/utils/utils.py index d156a87..1b936b4 100644 --- a/GoT/utils/utils.py +++ b/GoT/utils/utils.py @@ -229,6 +229,7 @@ def print_benchmark_result_loglikehood( print(f"Correct: {n_correct}") print(f"Wrong: {n_wrong}") + def download_mlflow_traces(n_max: int): traces = mlflow.search_traces(max_results=n_max, order_by=["timestamp DESC"]) traces.to_csv("traces.csv", index=False) diff --git a/pyproject.toml b/pyproject.toml index 20e60bf..27f02e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,5 +51,5 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [[tool.mypy.overrides]] -module = ["lm_eval.*", "datasets.*", "sympy.*"] +module = ["lm_eval.*", "datasets.*", "sympy.*", "arxiv.*", "wikipedia.*"] follow_untyped_imports = true