Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
fa71e8f
chore: improve prompt
MarkRagg Apr 15, 2026
5569cb6
chore: remove craft tools in hf formatter
MarkRagg Apr 16, 2026
53344dd
chore: create a single benchmark_run for all datasets
MarkRagg Apr 16, 2026
4cf3a01
feat: gaia benchmark added
MarkRagg Apr 16, 2026
929ebd0
chore: simplify codes
MarkRagg Apr 17, 2026
1045101
chore: add explanation of tool needed
MarkRagg Apr 17, 2026
c840f47
chore: remove comments
MarkRagg Apr 21, 2026
5bd0c62
chore: add method to download mlflow traces
MarkRagg Apr 22, 2026
b4a4dda
chore: change var name and simplify codes
MarkRagg Apr 23, 2026
0c93f80
feat: add wikipedia and arxiv tools
MarkRagg Apr 24, 2026
ea2f8e2
fix: fix names in type arg
MarkRagg Apr 24, 2026
8e8ff33
chore: add specific crafter LLM and improve prompt
MarkRagg Apr 25, 2026
f0cec82
style: change system folder architecture
MarkRagg Apr 25, 2026
fab5a9c
style: ruff format + ignores arxiv, wikipedia stubs
MarkRagg Apr 25, 2026
677cab6
chore: improve prompt
MarkRagg Apr 15, 2026
6dc36a1
chore: remove craft tools in hf formatter
MarkRagg Apr 16, 2026
4333141
chore: create a single benchmark_run for all datasets
MarkRagg Apr 16, 2026
02a6e8e
feat: gaia benchmark added
MarkRagg Apr 16, 2026
626007f
chore: simplify codes
MarkRagg Apr 17, 2026
f62ed94
chore: add explanation of tool needed
MarkRagg Apr 17, 2026
66811f2
chore: remove comments
MarkRagg Apr 21, 2026
073be66
chore: add method to download mlflow traces
MarkRagg Apr 22, 2026
67ecf42
chore: change var name and simplify codes
MarkRagg Apr 23, 2026
bef5e20
feat: add wikipedia and arxiv tools
MarkRagg Apr 24, 2026
f5738a2
fix: fix names in type arg
MarkRagg Apr 24, 2026
ca78485
chore: add specific crafter LLM and improve prompt
MarkRagg Apr 25, 2026
3a87d0f
style: change system folder architecture
MarkRagg Apr 25, 2026
def5e3c
style: ruff format + ignores arxiv, wikipedia stubs
MarkRagg Apr 25, 2026
345ce80
chore: improve evaluation
MarkRagg May 7, 2026
2b4f617
chore: code refinement
MarkRagg May 8, 2026
47ef8b4
feat: add custom runs on cli args
MarkRagg May 14, 2026
89e468d
chore: change argument 'type' in 'category'
MarkRagg May 15, 2026
d2f886f
chore: memorize tool crafted in each CraftingNode
MarkRagg May 20, 2026
2558b00
chore: improve parsing
MarkRagg May 20, 2026
0468e3e
style: change var names
MarkRagg May 20, 2026
ac7289b
chore: comment pypi release
MarkRagg May 20, 2026
b787ac3
fix: accidentally delete crafting condition
MarkRagg May 20, 2026
868b3d0
chore: add docstring control
MarkRagg May 20, 2026
b417846
style: ruff format
MarkRagg May 22, 2026
52d5dbb
chore: mypy check
MarkRagg May 22, 2026
6bfc5c4
chore: uncomment pypi config
MarkRagg May 22, 2026
7840f62
Merge branch 'develop' into feature/code-refinement
MarkRagg May 22, 2026
9825232
Merge pull request #16 from MarkRagg/feature/code-refinement
MarkRagg May 22, 2026
60a7cda
Merge branch 'main' into develop
MarkRagg Jun 11, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ jobs:
npm install
npx semantic-release
env:
# PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}
GITHUB_TOKEN: ${{ secrets.RELEASE_TOKEN }}
RELEASE_TEST_PYPI: ${{ github.event.repository.is_template || contains(github.repository, 'template') }}
# dry run if not on main/master branch, or if initial commit
Expand Down
7 changes: 1 addition & 6 deletions GoT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from dotenv import load_dotenv

from lm_eval import evaluator, tasks
from GoT.core.graph_model import call_graph
from GoT.experiments.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper
from GoT.cli.parse_args import call_benchmark, defining_and_parse_args
from GoT.utils.utils import (
Expand Down Expand Up @@ -62,14 +61,10 @@ def lm_eval_graph_benchmark():
print_benchmark_result_loglikehood(results, task_name, filter_val="none")


def custom_test():
call_graph("Solve this integral ∫x2⋅ex2dx")


def main():
# It could be changed with custom_test() to test a custom problem instead of the benchmark
args = defining_and_parse_args()
call_benchmark(args)
# download_mlflow_traces(50)


# let this be the last line of this file
Expand Down
7 changes: 7 additions & 0 deletions GoT/agent_tools/craft_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ def sanitize_input(query: str) -> str:

func = functions[0]

try:
docstring = ast.get_docstring(func)
if not docstring or not docstring.strip():
return "Error: missing docstring. A description of the function is mandatory for Gemini tools."
except TypeError:
return "Error: missing docstring. A description of the function is mandatory for Gemini tools."

for arg in func.args.args:
if arg.annotation is None:
return f"Error: missing type annotation for '{arg.arg}'"
Expand Down
14 changes: 11 additions & 3 deletions GoT/cli/parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use_gsm8k,
use_hendrycks_math,
)
from GoT.experiments.runner_custom import custom_test


def defining_and_parse_args():
Expand All @@ -17,7 +18,7 @@ def defining_and_parse_args():
"--benchmark",
required=True,
type=str,
choices=["gsm8k", "gpqa", "hendrycks_math", "gaia"],
choices=["gsm8k", "gpqa", "hendrycks_math", "gaia", "custom"],
help="The benchmark to run the model on.",
)
parser.add_argument(
Expand All @@ -27,14 +28,17 @@ def defining_and_parse_args():
choices=["graph", "standard"],
help="Whether to run the standard model or the graph model.",
)
parser.add_argument(
"--prompt", type=str, default="", help="Insert a prompt during a custom run."
)
parser.add_argument(
"--max_run",
type=int,
default=1,
help="The maximum number of runs for the benchmark.",
)
parser.add_argument(
"--type",
"--category",
type=str,
default="algebra",
choices=[
Expand Down Expand Up @@ -66,6 +70,10 @@ def call_benchmark(args):
elif args.benchmark == "gpqa":
use_gpqa(max_run=max_run, test=test, model_name=mode)
elif args.benchmark == "hendrycks_math":
use_hendrycks_math(max_run=max_run, test=test, model_name=mode, type=args.type)
use_hendrycks_math(
max_run=max_run, test=test, model_name=mode, type=args.category
)
elif args.benchmark == "gaia":
use_gaia(max_run=max_run, test=test, model_name=mode)
elif args.benchmark == "custom" and args.prompt != "":
custom_test(args.prompt, test)
18 changes: 8 additions & 10 deletions GoT/core/graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from GoT.utils.utils import (
extract_tool_used,
extract_tools_crafted,
parse_response,
parse_response_for_tool_node,
parse_score,
Expand Down Expand Up @@ -54,16 +55,16 @@
Your duty is to score, from 0 to 5, the response that user gives and assign a score.

Rules:
- If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem.
- You MUST respond ONLY using the Score function.
- You must consider if the format of the answer follow the instruction
- You cannot give the full solution, only hints.
- If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem.
- Do not write natural language outside the function.
- Always consider creating a tool if it makes the response correct or reusable.

Score meanings:
0: Impossible to understand / completely wrong
1: Nearly completely wrong
1: Nearly completely wrong / need to craft a tool
2: Correct language but does not follow instruction
3: Tries to solve but fails instruction / wrong
4: Follows instruction but result wrong or incomplete
Expand Down Expand Up @@ -168,6 +169,7 @@ def open_excel_files(excel_path: str)
Rules:
- Prefer generic names and parameters, never craft specific functions.
- If the function contains specific numbers or values, it is wrong.
- The Tool must follow the Json schema protocol, Tuple is banned.
- Never return hardcoded or placeholder strings, the function must fetch real data.
- Craft a maximum of 3 tools, it must contains always the docs. If the number of tool crafted exceed, you fail.
- Never craft tool that raise exceptions.
Expand Down Expand Up @@ -202,12 +204,11 @@ def goal(prompt: MessagesState):
runtime_graph.add_node(goal_node)
runtime_graph.temp_node = goal_node
for i in range(0, 3):
reasoning_node = ReasoningNode("")
call_node = ToolNode(
"Please, resolve the problem with the tools given, you MUST follow the previous reasoning.",
"",
tool_name="",
)
reasoning_node = ReasoningNode("")
runtime_graph.add_node(reasoning_node)
runtime_graph.add_node(call_node)
runtime_graph.add_edge(goal_node, reasoning_node)
Expand Down Expand Up @@ -242,7 +243,7 @@ def tool_call(messages: MessagesState):
# It calls the llm and it resolves the call node
call_node = runtime_graph.temp_node
tool_agent = LLM().create_custom_agent(
LLM().get_tools() + [divide_thought],
LLM().get_tools(),
SystemMessage(
"You are an assistant specialized in tools. Your goal is to resolve the problem with "
" the tool that the user indicates to you. You HAVE to use or craft the tool that the assistant indicates to you."
Expand Down Expand Up @@ -335,7 +336,7 @@ def response_evaluation(messages: MessagesState):


def crafting(messages: MessagesState):
crafting_node = CraftingNode(response="", tool_crafted="", resolved=False)
crafting_node = CraftingNode(response="", tools_crafted=[], resolved=False)
runtime_graph.add_node(crafting_node)
runtime_graph.add_edge(runtime_graph.temp_node, crafting_node)
runtime_graph.temp_node = crafting_node
Expand All @@ -352,12 +353,10 @@ def crafting(messages: MessagesState):
{"messages": crafting_messages},
config={"recursion_limit": MAX_INTERACTIONS},
)
crafting_node.tools_crafted = extract_tools_crafted(craft_res)
parsed_res = parse_response(craft_res)
except Exception:
parsed_res = ""
# runtime_graph.temp_response.response = parse_response_for_tool_node(
# craft_res
# ).response
runtime_graph.resolve_node(crafting_node, parsed_res)
runtime_graph.temp_node = runtime_graph.call_tool_node()
runtime_graph.add_edge(crafting_node, runtime_graph.temp_node)
Expand Down Expand Up @@ -454,7 +453,6 @@ def backtrack(messages: MessagesState):
runtime_graph.add_edge(
backtrack_node, runtime_graph.temp_node
) # tool call node that we want to resolve
# messages = runtime_graph.append_prompt_to_messages_state(runtime_graph.temp_node)
messages.get("messages", []).append(AIMessage(backtrack_node.feedback))
return messages

Expand Down
12 changes: 6 additions & 6 deletions GoT/core/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,26 +51,26 @@ def __init__(self):
api_key=os.environ.get("GEMINI_API_KEY"),
temperature=1.0, # Gemini 3.0+ defaults to 1.0
)
self.remoteLLMResponseFormat = ChatGoogleGenerativeAI(
self.remoteLLMReasoning = ChatGoogleGenerativeAI(
model="gemini-2.5-flash",
api_key=os.environ.get("GEMINI_API_KEY"),
temperature=1.0, # Gemini 3.0+ defaults to 1.0
)
self.remoteLLMCrafter = ChatGoogleGenerativeAI(
model="gemini-3-flash-preview",
model="gemini-2.5-flash",
api_key=os.environ.get("GEMINI_API_KEY"),
temperature=1.0, # Gemini 3.0+ defaults to 1.0
)
self.remoteLLMScoreFormat = ChatGoogleGenerativeAI(
self.remoteLLMEvaluator = ChatGoogleGenerativeAI(
model="gemini-2.5-flash",
api_key=os.environ.get("GEMINI_API_KEY"),
temperature=0.7, # Gemini 3.0+ defaults to 1.0
temperature=1.0, # Gemini 3.0+ defaults to 1.0
)

self.remoteLLMs = {
"remote_standard": self.remoteLLMStandard,
"remote_response_format": self.remoteLLMResponseFormat,
"remote_score_format": self.remoteLLMScoreFormat,
"remote_response_format": self.remoteLLMReasoning,
"remote_score_format": self.remoteLLMEvaluator,
"remote_crafter": self.remoteLLMCrafter,
}

Expand Down
32 changes: 5 additions & 27 deletions GoT/core/runtime_graph.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import Dict, List
from langgraph.graph import MessagesState
from langchain_core.messages import AnyMessage, HumanMessage
from pydantic import BaseModel, Field


class RuntimeNode:
_id_counter = 0 # Contatore globale per ID unici
_id_counter = 0 # global ID counter

def __init__(
self,
resolved: bool = False,
):
self.id = RuntimeNode._id_counter # ID unico per ogni nodo
self.id = RuntimeNode._id_counter
RuntimeNode._id_counter += 1
self.resolved = resolved

Expand Down Expand Up @@ -52,13 +51,11 @@ def __init__(
self,
prompt: str,
response: str,
tool_name: str,
resolved: bool = False,
):
super().__init__(resolved)
self.prompt = prompt
self.response = response
self.tool_name = tool_name


class GoalNode(RuntimeNode):
Expand Down Expand Up @@ -119,12 +116,12 @@ class CraftingNode(RuntimeNode):
def __init__(
self,
response: str,
tool_crafted: str = "",
tools_crafted: list[str] = [],
resolved: bool = False,
):
super().__init__(resolved)
self.response = response
self.tool_crafted = tool_crafted
self.tools_crafted = tools_crafted


class Score(BaseModel):
Expand Down Expand Up @@ -170,7 +167,6 @@ class RuntimeGraph:
def __init__(self):
self.goal: MessagesState = MessagesState(messages=[])
self.nodes: Dict[RuntimeNode, List[RuntimeNode]] = {}
self.tools_available: Dict[RuntimeNode, str] = {}
self.temp_node: RuntimeNode = RuntimeNode()
self.temp_response: ResponseNode = ResponseNode(response="", resolved=False)

Expand All @@ -181,9 +177,6 @@ def add_edge(self, n1: RuntimeNode, n2: RuntimeNode):
self.nodes.setdefault(n1, []).append(n2)
self.nodes.setdefault(n2, [])

def add_tool_link(self, call_node: RuntimeNode, tool_name: str):
self.tools_available.setdefault(call_node, tool_name)

def resolve_node(self, node: RuntimeNode, response: str) -> None:
if isinstance(node, (ToolNode, TestNode, CompletitionNode)):
node.response = response
Expand All @@ -208,31 +201,16 @@ def exist_tool_available(self) -> bool:
call_nodes = [n for n in nodes if (isinstance(n, ToolNode) and not n.resolved)]
return True if call_nodes else False

def get_resolved_tools(self):
resolved_nodes = [t for t in self.tools_available.keys() if t.resolved is True]
return [self.tools_available[n] for n in resolved_nodes]

def is_craftin_node_resolved(self) -> bool:
def is_crafting_node_resolved(self) -> bool:
nodes = list(self.nodes.keys())
crafting_nodes = [
n for n in nodes if (isinstance(n, CraftingNode) and n.resolved)
]
return True if crafting_nodes else False

def append_prompt_to_messages_state(
self, node: TestNode | ToolNode | CompletitionNode | GoalNode
) -> MessagesState:
messages: list[AnyMessage] = []

if node.prompt:
messages.append(HumanMessage(content=node.prompt))

return MessagesState(messages=messages)

def clear(self):
RuntimeNode._id_counter = 0
self.nodes = {}
self.tools_available = {}
self.temp_node = RuntimeNode()
self.temp_response = ResponseNode(response="", resolved=False)

Expand Down
7 changes: 4 additions & 3 deletions GoT/experiments/hf_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from GoT.core.graph_model import call_graph
from GoT.core.llm import LLM
from GoT.utils.utils import (
extract_answer_from_response,
extract_output,
normalize_list,
normalize_number,
parse_response,
symbolic_equal,
)

Expand Down Expand Up @@ -190,8 +192,7 @@ def hendrycks_math_eval(responses: list[ResultEval]):
correct = 0

for res in responses:
opt_res = re.search(r"\\boxed\{(.*)\}", res.response)
norm_res = opt_res.group(1) if opt_res else "N/A"
norm_res = extract_answer_from_response(res.response)
norm_correct = normalize_number(res.correct_answer)
res.filtered_answer = norm_res

Expand Down Expand Up @@ -222,7 +223,7 @@ def benchmark_run(
correct_answer = q.correct_answer
try:
if test:
response = extract_output(
response = parse_response(
agent.invoke(
{"messages": [HumanMessage(content=prompt)]},
config={"recursion_limit": 20},
Expand Down
15 changes: 15 additions & 0 deletions GoT/experiments/runner_custom.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from langchain.messages import HumanMessage

from GoT.core.graph_model import call_graph
from GoT.core.llm import LLM


def custom_test(text: str, is_graph_mode: bool):
if not is_graph_mode:
call_graph(text)
else:
agent = LLM().create_custom_agent(LLM().get_tools())
agent.invoke(
{"messages": [HumanMessage(content=text)]},
config={"recursion_limit": 20},
)
Loading
Loading