Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions GoT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dotenv import load_dotenv

from lm_eval import evaluator, tasks
from GoT.model.graph_model import call_graph
from GoT.model.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper
from GoT.model.utils.parse_args import call_benchmark, defining_and_parse_args
from GoT.model.utils.utils import (
from GoT.core.graph_model import call_graph
from GoT.experiments.lm_wrapper import LangGraphBigBenchWrapper, TestBigBenchWrapper
from GoT.cli.parse_args import call_benchmark, defining_and_parse_args
from GoT.utils.utils import (
print_benchmark_result,
print_benchmark_result_loglikehood,
)
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions GoT/tools/craft_tool.py → GoT/agent_tools/craft_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def craft_tool(tool_function: str) -> str:
"""Save the function definition provided by the LLM as a tool that can be used by other agents.
The function should be defined as a python function.
The function should be general and reusable, and should not be specific to the current problem.
The function must not use tuple as args type.
The function must be defined as gemini api format, with type annotations for all arguments and return type.
The function should be defined in a way that it can be imported and used by other agents."""

def sanitize_input(query: str) -> str:
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from langchain.messages import HumanMessage, SystemMessage
from langchain.tools import tool

from GoT.model.ollama_llm import LLM
from GoT.model.runtime_graph import ReasoningNode, RuntimeGraph
from GoT.model.utils.utils import parse_response
from GoT.core.llm import LLM
from GoT.core.runtime_graph import ReasoningNode, RuntimeGraph
from GoT.utils.utils import parse_response

MAX_INTERACTIONS = 10

Expand All @@ -21,6 +21,8 @@ def divide_thought(
HOW TO USE THIS TOOL:
- Call it when you think the problem is complex.
- The two parts must be as independent as possible.
IMPORTANT NOTES:
- You can't use the result of the first part to reason about the second part, and vice versa. The two parts must be as independent as possible.
Arguments:
- first_part: the first part of the thought process
- second_part: the second part of the thought process
Expand Down
54 changes: 54 additions & 0 deletions GoT/agent_tools/web_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import arxiv
from langchain.tools import tool
import wikipedia


@tool
def search_wikipedia(query: str) -> str:
    """
    Fetch a brief summary from Wikipedia.

    Args:
        query (str): The keyword or topic to search for.

    Returns:
        str: A 3-sentence summary of the topic, the first option if
        ambiguous, or an error message if not found.
    """
    try:
        # `wikipedia.summary` returns the page text as a string;
        # `wikipedia.search` would only return a list of candidate titles
        # and never raises the page-level errors handled below.
        return wikipedia.summary(query, sentences=3)
    except wikipedia.DisambiguationError as e:
        # The query is ambiguous: fall back to the first suggested page.
        if not e.options:
            return "Page not found"
        try:
            return wikipedia.summary(e.options[0], sentences=3)
        except (wikipedia.DisambiguationError, wikipedia.PageError):
            # The fallback title can itself be ambiguous or missing; report
            # it like any other miss instead of letting the tool call crash.
            return "Page not found"
    except wikipedia.PageError:
        return "Page not found"


@tool
def search_arxiv(query: str) -> str:
    """Search ArXiv for scientific papers on a given topic.
    Use this when you need to find research papers, abstracts or academic references."""

    def _render(paper) -> str:
        # One newline-terminated text record per paper.
        author_names = ", ".join(author.name for author in paper.authors)
        published_on = paper.published.strftime("%Y-%m-%d")
        fields = (
            f"Title: {paper.title}",
            f"Authors: {author_names}",
            f"Published: {published_on}",
            f"Summary: {paper.summary[:300]}...",
            f"URL: {paper.entry_id}",
        )
        return "".join(f"{field}\n" for field in fields)

    try:
        client = arxiv.Client()
        lookup = arxiv.Search(
            query=query, max_results=3, sort_by=arxiv.SortCriterion.Relevance
        )
        records = [_render(paper) for paper in client.results(lookup)]
    except Exception as exc:
        # Best-effort tool: hand the failure text back to the calling agent.
        return f"Error searching ArXiv: {exc!s}"

    if records:
        return "\n---\n".join(records)
    return "No papers found for this query."
13 changes: 10 additions & 3 deletions GoT/model/utils/parse_args.py → GoT/cli/parse_args.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import argparse
import sys

from GoT.model.utils.hf_formatter import use_gpqa, use_gsm8k, use_hendrycks_math
from GoT.experiments.hf_formatter import (
use_gaia,
use_gpqa,
use_gsm8k,
use_hendrycks_math,
)


def defining_and_parse_args():
Expand All @@ -12,7 +17,7 @@ def defining_and_parse_args():
"--benchmark",
required=True,
type=str,
choices=["gsm8k", "gpqa", "hendrycks_math"],
choices=["gsm8k", "gpqa", "hendrycks_math", "gaia"],
help="The benchmark to run the model on.",
)
parser.add_argument(
Expand All @@ -39,7 +44,7 @@ def defining_and_parse_args():
"intermediate_algebra",
"number_theory",
"precalculus",
"statistics",
"prealgebra",
],
help="The type of math problems to run, only for hendrycks_math.",
)
Expand All @@ -62,3 +67,5 @@ def call_benchmark(args):
use_gpqa(max_run=max_run, test=test, model_name=mode)
elif args.benchmark == "hendrycks_math":
use_hendrycks_math(max_run=max_run, test=test, model_name=mode, type=args.type)
elif args.benchmark == "gaia":
use_gaia(max_run=max_run, test=test, model_name=mode)
144 changes: 97 additions & 47 deletions GoT/model/graph_model.py → GoT/core/graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
from langgraph.graph import StateGraph, MessagesState, START, END

from GoT.model.ollama_llm import LLM
from GoT.model.runtime_graph import (
from GoT.core.llm import LLM
from GoT.core.runtime_graph import (
BacktrackNode,
CompletitionNode,
CraftingNode,
Expand All @@ -15,13 +15,13 @@
TestNode,
ToolNode,
)
from GoT.model.utils.utils import (
from GoT.utils.utils import (
extract_tool_used,
parse_response,
parse_response_for_tool_node,
parse_score,
)
from GoT.tools.runtime_graph_tool import divide_thought
from GoT.agent_tools.runtime_graph_tool import divide_thought

SCORE_THRESHOLD = 5
COMPLEXITY_COEFFICIENT = 0.5
Expand Down Expand Up @@ -55,6 +55,7 @@

Rules:
- You MUST respond ONLY using the Score function.
- You must consider if the format of the answer follow the instruction
- You cannot give the full solution, only hints.
- If a response suggest the need of crafting a tool, score it with 1 or less and specify clearly the need of a new tool to solve the problem.
- Do not write natural language outside the function.
Expand All @@ -81,6 +82,7 @@
LLM().get_craft_tool(),
SystemMessage(
"""
You are a specialized in coding and write new useful method.
You create reusable Python tools for other agents.

The tool must be GENERAL and parameterized.
Expand All @@ -101,18 +103,80 @@ def multiply(a: float, b: float) -> float:
'
return a * b


Bad example (hardcoded/placeholder result):
def search_papers(query: str) -> str:
return "Results about " + query # WRONG: never return hardcoded strings

Good example (real API call):
def search_papers(query: str) -> str:
'
Arguments:
query: the search query string
Returns:
A string with real results fetched from the API
'
import arxiv
client = arxiv.Client()
search = arxiv.Search(query=query, max_results=3)
results = [p.title + ": " + p.summary[:200] for p in client.results(search)]
return "\\n".join(results)

Bad example: Too specific
def get_oldest_blu_ray_title(spreadsheet_path: str) -> str:
"
Analyzes a spreadsheet to find the oldest Blu-Ray title.

Arguments:
spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx').

Returns:
The title of the oldest Blu-Ray as it appears in the spreadsheet.
"
import pandas as pd

df = pd.read_excel(spreadsheet_path)

# Assuming 'Format' column for media type and 'Recording Date' for date
blu_rays = df[df['Format'] == 'Blu-Ray']

if blu_rays.empty:
return "No Blu-Ray titles found."

# Ensure 'Recording Date' is in datetime format for proper comparison
blu_rays['Recording Date'] = pd.to_datetime(blu_rays['Recording Date'])

oldest_blu_ray = blu_rays.sort_values(by='Recording Date', ascending=True).iloc[0]

return oldest_blu_ray['Title']

Good example
def open_excel_files(excel_path: str)
Analyzes a spreadsheet.

Arguments:
spreadsheet_path: The file path to the spreadsheet (e.g., 'C:/Users/user/data.xlsx').

Returns:
The excel file in string
"
import pandas as pd

df = pd.read_excel(spreadsheet_path)
return df.to_string()

Rules:
- Prefer generic names and parameters, never craft specific functions.
- If the function contains specific numbers or values, it is wrong.
- Craft only one function, it must contains always the docs.
- Never return hardcoded or placeholder strings, the function must fetch real data.
- Craft a maximum of 3 tools, it must contains always the docs. If the number of tool crafted exceed, you fail.
- Never craft tool that raise exceptions.
- Respond ONLY using the tool available.
- No natural language.
- No comments in the python interpreter.
- No more than 1 line comments in the python codes.
"""
),
response_format=Response,
type="remote_response_format",
type="remote_crafter",
)

reasoning_agent = LLM().create_custom_agent(
Expand Down Expand Up @@ -156,22 +220,6 @@ def goal(prompt: MessagesState):
return prompt


# def tool_expand(goal: MessagesState):
# msg = parse_response(goal)
# sys_msg = "Please make a list using '-' to denote each tool in a probabilistic order, don't use this character for other reasons. Select only the tool(s) you want to use to solve this problem."
# messages = [
# HumanMessage(msg),
# SystemMessage(sys_msg),
# ]
# res = starting_agent.invoke({"messages": messages}, config={"recursion_limit": MAX_INTERACTIONS})
# str_res = parse_response(res)
# goal["messages"].append(AIMessage(content=str_res))
# # tool_list = parse_tool_list(str_res) # Removes unneeded elements
# # add tool nodes in the runtime graph

# return goal


def tool_reasoning(messages: MessagesState):
messages["messages"].append(
HumanMessage(
Expand Down Expand Up @@ -235,6 +283,9 @@ def tool_call(messages: MessagesState):
)
tool_used = extract_tool_used(res)
runtime_graph.temp_response.response = parse_response_for_tool_node(res).response
runtime_graph.temp_response.explanation = parse_response_for_tool_node(
res
).explanation
parsed_res = f"Response: {parse_response_for_tool_node(res).response}\nExplanation: {parse_response_for_tool_node(res).explanation}"
runtime_graph.resolve_node(call_node, parsed_res)

Expand Down Expand Up @@ -288,19 +339,25 @@ def crafting(messages: MessagesState):
runtime_graph.add_node(crafting_node)
runtime_graph.add_edge(runtime_graph.temp_node, crafting_node)
runtime_graph.temp_node = crafting_node
ai_feedback = runtime_graph.temp_response.explanation
crafting_messages = [
HumanMessage(content="Original task:\n" + parse_response(runtime_graph.goal)),
AIMessage(content=ai_feedback),
SystemMessage(
content="Craft a tool to solve this problem using craft_tool. It must be a function"
content="Use the context given to craft a tool to solve this problem using craft_tool. It must be a function"
),
]
craft_res = crafter_agent.invoke(
{"messages": crafting_messages}, config={"recursion_limit": MAX_INTERACTIONS}
)
runtime_graph.temp_response.response = parse_response_for_tool_node(
craft_res
).response
parsed_res = f"Response: {parse_response_for_tool_node(craft_res).response}\nExplanation: {parse_response_for_tool_node(craft_res).explanation}"
try:
craft_res = crafter_agent.invoke(
{"messages": crafting_messages},
config={"recursion_limit": MAX_INTERACTIONS},
)
parsed_res = parse_response(craft_res)
except Exception:
parsed_res = ""
# runtime_graph.temp_response.response = parse_response_for_tool_node(
# craft_res
# ).response
runtime_graph.resolve_node(crafting_node, parsed_res)
runtime_graph.temp_node = runtime_graph.call_tool_node()
runtime_graph.add_edge(crafting_node, runtime_graph.temp_node)
Expand All @@ -322,36 +379,31 @@ def crafting(messages: MessagesState):


def test_result(messages: MessagesState):
n = runtime_graph.exist_tool_available()
is_tool_path_available = runtime_graph.exist_tool_available()
test_node = runtime_graph.temp_node
if not isinstance(test_node, TestNode):
raise TypeError("Expected TestNode for scoring")

if test_node.score >= (
threshold = (
COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity
):
)
if test_node.score >= threshold:
runtime_graph.add_edge(test_node, runtime_graph.temp_response)
runtime_graph.temp_response.resolved = True
return END
elif (
test_node.score
< (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity)
and n is True
test_node.score < threshold
and is_tool_path_available is True
and test_node.need_tool_crafting is True
):
return "crafting"
elif (
test_node.score
< (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity)
and n is True
):
elif test_node.score < threshold and is_tool_path_available is True:
if test_node.need_tool_crafting is True:
test_node.response = "The problem is too complex to craft a new tool, try reason step by step or divide complexity."
return "backtrack"
elif (
test_node.score
< (COMPLEXITY_THRESHOLD - COMPLEXITY_COEFFICIENT * test_node.problem_complexity)
and n is False
and is_tool_path_available is False
and runtime_graph.exist_reasoning_node_available()
):
return "reasoning_mode"
Expand Down Expand Up @@ -433,7 +485,6 @@ def call_graph(prompt: str):
def invoke_graph():
graph = StateGraph(MessagesState)
graph.add_node(goal)
# graph.add_node(tool_expand)
graph.add_node(tool_reasoning)
graph.add_node(tool_call)
graph.add_node(backtrack)
Expand All @@ -442,7 +493,6 @@ def invoke_graph():
graph.add_node(response_evaluation)
graph.add_node(reasoning_mode)
graph.add_edge(START, "goal")
# graph.add_edge("goal", "tool_expand")
graph.add_edge("goal", "tool_reasoning")
graph.add_edge("tool_reasoning", "tool_call")
graph.add_edge("tool_call", "response_evaluation")
Expand Down
Loading