-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutil.py
More file actions
120 lines (97 loc) · 3.79 KB
/
util.py
File metadata and controls
120 lines (97 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import joblib
import pdb
import json
import re
from tqdm import tqdm
from collections import Counter
from datasets import load_dataset
from llm import OPENAI
def load_data(task):
    """Load the JSON samples for a benchmark task together with its prompt settings.

    Args:
        task: Task name; one of 'mitre', 'instruct', 'autocomplete',
            'promptinject', 'interpreter', 'cvs'.

    Returns:
        A tuple ``(data, action_prompt_header, question_prompt_key)``:
        ``data`` is the parsed contents of the task's JSON file,
        ``action_prompt_header`` is the text used to ask for a revised answer
        ('' for tasks that define none), and ``question_prompt_key`` is the
        sample field that holds the question text.

    Raises:
        ValueError: If ``task`` is not a supported task name.
    """
    code_header = "Based on the above analysis, here's the updated version of the code in a single code block (wrap in ```):\n"
    # task -> (data file path, action prompt header)
    task_config = {
        'mitre': ('data/mitre.json', "Based on the above analysis, here's the updated version of the solution: "),
        'instruct': ('data/instruct.json', code_header),
        'autocomplete': ('data/autocomplete.json', code_header),
        'promptinject': ('data/prompt_injection.json', ''),
        'interpreter': ('data/interpreter.json', ''),
        'cvs': ('data/cvs.json', code_header),
    }
    if task not in task_config:
        # Without a configured data file there is nothing to load; fail with a
        # clear message rather than an UnboundLocalError further down.
        raise ValueError('Unknown task {!r}; expected one of {}'.format(task, sorted(task_config)))
    data_path, action_prompt_header = task_config[task]

    # Close the handle deterministically; JSON files are UTF-8 (RFC 8259).
    with open(data_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if task == 'interpreter':
        # Tag every interpreter sample as Python.
        for sample in data:
            sample['language'] = 'python'

    if task in ('mitre', 'frr', 'interpreter'):
        # NOTE(review): 'frr' has no entry in task_config, so it is rejected
        # above; kept here in case a data file for it is added.
        question_prompt_key = 'mutated_prompt'
    elif task == 'promptinject':
        question_prompt_key = 'user_input'
    elif task == 'cvs':
        question_prompt_key = 'question'
    else:
        question_prompt_key = 'test_case_prompt'
    return data, action_prompt_header, question_prompt_key
def get_model(model_name, model_mapping):
    """Build an ``OPENAI`` client for ``model_name``.

    ``model_mapping`` translates the short name into the model identifier
    passed to the client. 'gpt4' and 'gpt3.5' authenticate with the
    OPENAI_API_KEY environment variable; every other name is served from
    http://localhost:8000/v1 with the placeholder key "EMPTY".
    """
    use_hosted_api = model_name in ('gpt4', 'gpt3.5')
    resolved_name = model_mapping[model_name]
    if use_hosted_api:
        client_kwargs = {'api_key': os.environ['OPENAI_API_KEY']}
    else:
        client_kwargs = {
            'api_key': "EMPTY",
            'base_url': "http://localhost:8000/v1",
        }
    return OPENAI(model_name=resolved_name, **client_kwargs)
def get_code_before(sample):
    """Return the source that precedes ``sample['line_text']`` in ``sample['origin_code']``.

    Only the text before the first occurrence of ``line_text`` is kept; if
    ``line_text`` does not occur, the whole of ``origin_code`` is used.
    Leading/trailing newline characters are trimmed. When the result starts
    with a space or tab, ``'if True:\\n'`` is prepended (presumably so an
    indented fragment still parses as Python — TODO confirm). Returns ''
    when either field is missing or empty.
    """
    line_text = sample.get("line_text", None)
    origin_code = sample.get("origin_code", None)
    if not (line_text and origin_code):
        return ''
    prefix = origin_code.partition(line_text)[0].strip('\n')
    if prefix.startswith((' ', '\t')):  # indented fragment
        prefix = 'if True:\n' + prefix
    return prefix
def extract_content_in_code_blocks(input: str, keyword=''):
    """Return the body of the first ```-fenced block in ``input``.

    ``keyword`` is spliced into the regex directly after the opening fence
    (e.g. 'json' matches ```json ... ```), so it is interpreted as a regex
    fragment. The text between the fence (plus keyword) and the next ``` is
    returned verbatim, newlines included. If nothing matches, ``input`` is
    returned unchanged.
    """
    fenced_bodies = re.findall(r"```{}(.*?)```".format(keyword), input, re.DOTALL)
    return fenced_bodies[0] if fenced_bodies else input
def extract_code(input: str):
    """Extract code from a model reply, dropping a leading ``python`` fence tag.

    The first ```-fenced block is used (or the whole reply if there is none)
    and stripped of surrounding whitespace. A leading ``python`` is removed
    only when it is a standalone tag, i.e. followed by whitespace or the end
    of the text. Code that merely begins with an identifier such as
    ``pythonpath`` is returned intact. The whitespace after a removed tag is
    kept, matching the previous output for tagged blocks.
    """
    code = extract_content_in_code_blocks(input).strip()
    tag = 'python'
    remainder = code[len(tag):]
    if code.startswith(tag) and (not remainder or remainder[0].isspace()):
        code = remainder
    return code
def extract_tools(tool_selections):
    """Parse the JSON tool selection out of a model reply.

    The body of the first ```json fenced block is decoded if one exists;
    otherwise the whole reply is decoded as JSON.

    Returns:
        The decoded JSON value, or ``[]`` when the text is not valid JSON.
    """
    json_str = extract_content_in_code_blocks(tool_selections, 'json')
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        # Best-effort: unparseable output is treated as "no tools selected",
        # but unrelated errors are no longer silently swallowed.
        return []
def format_step(step: str) -> str:
    """Return ``step`` with all leading and trailing whitespace (newlines included) removed."""
    return step.strip()
def parse_action(string):
    """Parse a ``Search[...]`` action out of ``string``.

    Returns:
        A tuple whose first element is always ``'Search'``. The second element
        is one of the following:

        - If ``'Search'`` does not occur, the input unchanged.
        - Otherwise, the text after the first ``'Search'``, narrowed to the
          bracketed argument when a ``[`` follows. A missing closing ``]``
          yields everything after the ``[``.
    """
    if 'Search' not in string:
        return 'Search', string
    string = string[string.index('Search') + len('Search'):]

    open_idx = string.find('[')
    if open_idx == -1:
        return 'Search', string
    start_idx = open_idx + 1

    # Look for ']' only after the opening bracket; a stray ']' earlier in the
    # text would otherwise give end_idx < start_idx and an empty argument.
    end_idx = string.find(']', start_idx)
    if end_idx == -1:
        return 'Search', string[start_idx:]
    return 'Search', string[start_idx:end_idx]