-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrunner.py
More file actions
90 lines (76 loc) · 4.62 KB
/
runner.py
File metadata and controls
90 lines (76 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import re
import argparse
import sys
import LLM
from env import Environment
from agents.agent import Agent
from agents.attack_agent import AttackAgent
from agents.controller_agent import ControllerAgent
def calculate_total_tokens(log_dir):
    """Sum the token counts recorded in the log files under ``log_dir``.

    Walks ``log_dir`` recursively and reads every file whose name ends with
    ``_log.log`` or contains ``_log.log_``. Every integer following the
    markers ``Number of prompt tokens: `` and ``Number of sampled tokens: ``
    is added to the respective total. The totals are then written to
    ``input_tokens.txt`` and ``output_tokens.txt`` directly inside
    ``log_dir``.

    Args:
        log_dir: Directory to scan. It is created when missing so that the
            two summary files can always be written.

    Returns:
        A ``(total_prompt_tokens, total_sample_tokens)`` tuple of ints;
        ``(0, 0)`` when no matching log file (or no marker) is found.
    """
    # Compile once instead of re-resolving the patterns for every file.
    prompt_pattern = re.compile(r"Number of prompt tokens: (\d+)")
    sample_pattern = re.compile(r"Number of sampled tokens: (\d+)")

    total_prompt_tokens = 0
    total_sample_tokens = 0
    for root, _dirs, files in os.walk(log_dir):
        for file_name in files:
            if not (file_name.endswith("_log.log") or "_log.log_" in file_name):
                continue
            # Only the ASCII marker lines matter, so undecodable bytes
            # elsewhere in a log must not abort the whole tally.
            log_path = os.path.join(root, file_name)
            with open(log_path, "r", encoding="utf-8", errors="replace") as f:
                content = f.read()
            total_prompt_tokens += sum(map(int, prompt_pattern.findall(content)))
            total_sample_tokens += sum(map(int, sample_pattern.findall(content)))

    # os.walk silently yields nothing for a missing directory, but the
    # writes below would fail; make sure the target directory exists.
    # An empty log_dir means "current directory" and needs no creation.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    with open(os.path.join(log_dir, "input_tokens.txt"), "w", encoding="utf-8") as f:
        f.write(str(total_prompt_tokens))
    with open(os.path.join(log_dir, "output_tokens.txt"), "w", encoding="utf-8") as f:
        f.write(str(total_sample_tokens))
    return total_prompt_tokens, total_sample_tokens
def run(agent_cls, args):
    """Run one agent session inside an ``Environment`` and report the result.

    Prints the task description and the action names exposed by the
    environment, builds the agent, runs it, prints the token totals gathered
    from ``args.log_dir`` and the agent's final message, and finally calls
    ``env.save("final")``.

    Args:
        agent_cls: Callable invoked as ``agent_cls(args, env)`` to build the
            agent; ignored when ``args.task == "all_in_one"``, in which case
            a ``ControllerAgent`` is used instead.
        args: Parsed command-line namespace; ``task`` and ``log_dir`` are read
            here and the whole namespace is passed to the environment and
            the agent.
    """
    separator = "====================================="
    with Environment(args) as env:
        print(separator)
        task_prompt, benchmark_folder = env.get_task_description()
        print("Benchmark folder: ", benchmark_folder)
        print("Task Prompt: ", task_prompt)
        low_level_names = [action.name for action in env.low_level_actions]
        print("Lower level actions enabled for AttackAgent: ", low_level_names)
        high_level_names = [action.name for action in env.high_level_actions]
        print("High level actions enabled for AttackAgent: ", high_level_names)
        print("Read only files: ", env.read_only_files, file=sys.stderr)
        print(separator)

        if args.task != "all_in_one":
            agent = agent_cls(args, env)
        else:
            agent = ControllerAgent(args, env)
            print("High-level actions enabled for ControllerAgent: ", agent.all_tool_names)

        final_message = agent.run(env)

        prompt_total, sample_total = calculate_total_tokens(args.log_dir)
        print(f"Total prompt tokens: {prompt_total}")
        print(f"Total sample tokens: {sample_total}")
        print(separator)
        print("Final message: ", final_message)

    # NOTE(review): the final save is placed after the context manager exits;
    # confirm this matches the intended nesting.
    env.save("final")
if __name__ == "__main__":
    # Command-line entry point: parse the run/agent options, export the fast
    # LLM name through the environment, and launch the selected agent.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default="mia", help="task name (mia, attr_infer, data_recon, model_steal, all_in_one)")
    parser.add_argument("--task-config", type=str, default="", help="task config file under the benchmark folder")
    parser.add_argument("--log-dir", type=str, default="./logs", help="log dir")
    parser.add_argument("--work-dir", type=str, default="./workspace", help="work dir")
    parser.add_argument("--max-steps", type=int, default=50, help="number of steps")
    parser.add_argument("--max-time", type=int, default=5 * 60 * 60, help="max time")
    parser.add_argument("--device", type=int, default=0, help="device id")
    parser.add_argument("--python", type=str, default="python", help="python command")
    parser.add_argument("--interactive", action="store_true", help="interactive mode")
    parser.add_argument("--resume", type=str, default=None, help="resume from a previous run")
    parser.add_argument("--resume-step", type=int, default=0, help="the step to resume from")

    # agent configs
    parser.add_argument("--agent-type", type=str, default="AttackAgent", help="agent type")
    parser.add_argument("--llm-name", type=str, default="gpt-4o", help="llm name")
    parser.add_argument("--agent-max-steps", type=int, default=50, help="max iterations for agent")
    parser.add_argument("--actions-remove-from-prompt", type=str, nargs='+', default=[], help="actions to remove from prompt")
    parser.add_argument("--actions-add-to-prompt", type=str, nargs='+', default=[], help="actions to add to prompt")
    # The flag spelling ("entires") is kept as-is: renaming it would change
    # both the CLI and the resulting attribute name on ``args``.
    parser.add_argument("--valid-format-entires", type=str, nargs='+', default=None, help="valid format entries")
    parser.add_argument("--max-steps-in-context", type=int, default=3, help="max steps in context")
    parser.add_argument("--max-observation-steps-in-context", type=int, default=3, help="max observation steps in context")
    parser.add_argument("--max-retries", type=int, default=30, help="max retries")
    parser.add_argument("--fast-llm-name", type=str, default="gpt-4o", help="fast llm name")

    args = parser.parse_args()
    os.environ['FAST_LLM_NAME'] = args.fast_llm_name
    print(args, file=sys.stderr)

    # The agent is looked up by name among this module's globals. Reject
    # unknown names and non-callable attributes (e.g. an imported module)
    # with a usage error instead of an AttributeError/TypeError traceback.
    agent_cls = getattr(sys.modules[__name__], args.agent_type, None)
    if not callable(agent_cls):
        parser.error(f"unknown agent type: {args.agent_type!r}")
    run(agent_cls, args)