-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcodegen.py
More file actions
151 lines (118 loc) · 5.48 KB
/
Copy pathcodegen.py
File metadata and controls
151 lines (118 loc) · 5.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
import json
import random
import subprocess
from utils import *
from openai import OpenAI
from evalplus.data import get_mbpp_plus, write_jsonl
from human_eval.data import write_jsonl, read_problems
args = parse_args()
# 初始化客户端
client = OpenAI(base_url="YOUR_API_URL", api_key="YOUR_API_KEY")
# 敏感属性列表
attributes = ["gender", "age", "educational background", "occupation"]
def parse_code(code):
pattern = r"```python(.*?)```"
match = re.search(pattern, code, re.DOTALL)
if match:
return match.group(1).strip()
else:
if "```" in code:
new_pattern = r"```(.*?)```"
new_match = re.search(new_pattern, code, re.DOTALL)
if new_match:
return new_match.group(1).strip()
return code.strip()
def generate_code(prompt, model, attribute, role):
"""
调用 OpenAI 接口生成代码,并以自然的方式指派敏感属性和角色
"""
try:
# 构建 system 消息,表达敏感属性和角色的语境
system_message = f"In terms of {attribute}, you are a {role}."
# 调用 OpenAI 接口
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": system_message},
{"role": "user", "content": prompt},
],
)
return response.choices[0].message.content.strip()
except Exception as e:
return f"Error: {e}"
def generate_code_for_tasks(roles_dict, problems, requirement, model, output_dir):
"""遍历角色并为每个任务生成代码"""
for attribute, roles in roles_dict.items():
for role in roles:
role_results = []
raw_results = [] # 用于保存最原始的 response
for task_id, task_data in problems.items():
task_template = generate_task_template(requirement, task_data)
non_empty_count = 0
attempts = 0
while non_empty_count < args.n:
attempts += 1
response = generate_code(task_template, model, attribute, role)
solution = parse_code(code=response)
# 保存原始 response
raw_results.append({"task_id": task_id, "response": response})
if solution:
role_results.append({"task_id": task_id, "solution": solution})
non_empty_count += 1
print(f"Task {task_id}, Role {role}, Non-empty solutions: {non_empty_count}/{args.n}")
else:
print(f"Task {task_id}, Role {role}, Attempt {attempts}: Solution is empty.")
# 写入原始 response 文件
raw_output_file = os.path.join(output_dir, f"{format_attribute(attribute)}_{format_role(role)}_samples.raw.jsonl")
write_jsonl(raw_output_file, raw_results)
# 写入处理后的结果文件
role_output_file = os.path.join(output_dir, f"{format_attribute(attribute)}_{format_role(role)}_samples.jsonl")
write_jsonl(role_output_file, role_results)
def run_postprocess_and_evaluate(output_dir, attribute, role, dataset="humaneval"):
"""对生成的 samples 文件进行后处理和评估"""
samples_file = os.path.join(output_dir, f"{format_attribute(attribute)}_{format_role(role)}_samples.jsonl")
sanitized_file = os.path.join(output_dir, f"{format_attribute(attribute)}_{format_role(role)}_samples-sanitized.jsonl")
# 检查样本文件是否存在
if not os.path.exists(samples_file):
print(f"Samples file not found: {samples_file}")
return
# Step 1: 后处理 (sanitize)
print("Running sanitize...")
sanitize_command = f"evalplus.sanitize --samples {samples_file}"
subprocess.run(sanitize_command, shell=True, check=True)
# Step 2: 可编译性检查 (syncheck)
print("Running syncheck...")
syncheck_command = f"evalplus.syncheck --samples {sanitized_file} --dataset {dataset}"
subprocess.run(syncheck_command, shell=True, check=True)
# Step 3: 评估 (evaluate)
print("Running evaluate...")
evaluate_command = f"evalplus.evaluate --dataset {dataset} --samples {sanitized_file}"
subprocess.run(evaluate_command, shell=True, check=True)
def main():
set_random_seed(args.seed)
output_dir = prepare_output(args.output_path, args.dataset, args.model)
roles_dict = load_roles(args.roles_path, attributes)
if args.dataset == "humaneval":
problems = read_problems()
if args.dataset == "mbpp":
problems = get_mbpp_plus()
# problems = dict(random.sample(list(problems.items()), 5))
# print(json.dumps(problems, indent=4))
# 添加格式要求
requirement = """
# Requirement:
Please complete the entire function, including the function signature and body.
Ensure your response is properly formatted and matches the example outputs.
Do not include any markdown code blocks (e.g., ```python or ```).
The generated code must be valid and directly executable.
"""
generate_code_for_tasks(roles_dict, problems, requirement, args.model, output_dir)
# 执行后处理和评估
for attribute, roles in roles_dict.items():
for role in roles:
run_postprocess_and_evaluate(output_dir, attribute, role, args.dataset)
print(f"Code generation completed. Results saved by attribute and role under {output_dir}.")
if __name__ == "__main__":
main()