-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathPromptBuilder.py
More file actions
172 lines (149 loc) · 6.61 KB
/
PromptBuilder.py
File metadata and controls
172 lines (149 loc) · 6.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import configparser
import tiktoken
from transformers import AutoTokenizer
from db import DataBase
from prompt_templates.prompt_template_1 import prompt_template_1
from prompt_templates.prompt_template_2 import prompt_template_2
from prompt_templates.system_prompt import system_prompt
class PromptBuilder:
    """
    Builds the prompts sent to the LLM for methods and classes stored in the database.

    Settings are read from the ``MODEL`` section of ``config.ini``:

    * ``USE_MODEL`` - when true, prompts are tokenized with the tokenizer loaded
      from ``MODEL_PATH``; otherwise tiktoken's ``cl100k_base`` encoding is used.
    * ``MODEL_PATH`` - only read when ``USE_MODEL`` is true.
    * ``MODEL_MAX_INPUT_TOKENS`` - the token budget a prompt has to stay below.
    """

    # Frameworks named in every generated prompt.
    _TESTING_FRAMEWORK = "Unittest"
    _MOCKING_FRAMEWORK = "MagicMock"

    def __init__(self):
        """
        Opens the database and loads the tokenizer settings from ``config.ini``.
        """
        self.db = DataBase()
        self.config = configparser.ConfigParser()
        # The relative path means the file is looked up in the current working directory.
        self.config.read('config.ini')
        self.USE_MODEL = self.config.getboolean('MODEL', 'USE_MODEL')
        if self.USE_MODEL:
            self.MODEL_PATH = self.config.get('MODEL', 'MODEL_PATH')
        # Needed by check_token_limit on both tokenizer paths, so it is read unconditionally.
        self.max_tokens = int(self.config.get('MODEL', 'MODEL_MAX_INPUT_TOKENS'))
        # Tokenizers already loaded by tokenize_with_model, keyed by model path.
        self._tokenizers = {}

    def construct_initial_prompt(self, method_id):
        """
        Builds the initial prompt for the method with the given database id.

        :param method_id: id passed to ``self.db.get_method_by_id``
        :return: the prompt, or an empty string if it does not fit the token limit
        """
        method = self.db.get_method_by_id(method_id)
        prompt = self._generate_prompts_with_different_size(
            method["methodIdentifier"], method["classIdentifier"], method)
        return self._prompt_or_empty(prompt)

    def construct_initial_prompt_class(self, class_identifier):
        """
        Builds the initial prompt for a whole class.

        :param class_identifier: identifier passed to ``self.db.get_class``; also
            used as the class name inside the prompt
        :return: the prompt
        """
        class_ = self.db.get_class(class_identifier)
        # NOTE(review): unlike the method variants, this has always returned the
        # prompt even when it exceeds the token limit. That result is kept as-is;
        # confirm whether an empty string would be the intended outcome here.
        return self._generate_prompts_for_class(class_identifier, class_)

    def construct_initial_prompt_method(self, method_identifier, class_identifier):
        """
        Builds the initial prompt for a method looked up by method and class identifier.

        :param method_identifier: the identifier of the method
        :param class_identifier: the identifier of the class that owns the method
        :return: the prompt, or an empty string if it does not fit the token limit
        """
        method_id = self.db.get_method_id(method_identifier, class_identifier)
        method = self.db.get_method_by_id(method_id)
        prompt = self._generate_prompts_for_method_(method, class_identifier)
        # An over-long prompt yields "" (same as construct_initial_prompt) instead
        # of failing on a result variable that was never assigned.
        return self._prompt_or_empty(prompt)

    def construct_error_prompt(self, method_id, error_message):
        """
        Placeholder: not implemented yet, does nothing and returns None.

        :param method_id: currently unused
        :param error_message: currently unused
        """
        pass

    @staticmethod
    def construct_code_prompt_from_dict_list(code_list, language_identifier, is_method):
        """
        Constructs a prompt from a list of code snippets.

        :param code_list: iterable of dicts with a "fullText" entry and, when
            ``is_method`` is true, "methodIdentifier" and "classIdentifier" entries
        :param language_identifier: the language tag placed after the opening code fence
        :param is_method: True to put a "This is the method ..." line before each snippet
        :return: Single string containing all code snippets wrapped in code blocks,
            or "No relations found." when ``code_list`` is empty
        """
        sections = []
        for code in code_list:
            if is_method:
                sections.append(
                    f"This is the method: {code['methodIdentifier']}"
                    f" of the class {code['classIdentifier']}:\n"
                )
            sections.append(f"```{language_identifier}\n{code['fullText']}\n```\n")
        # Every snippet contributes at least its fences, so the joined text is
        # empty only when there were no snippets at all.
        return "".join(sections) or "No relations found."

    def check_token_limit(self, prompt: str) -> bool:
        """
        Checks whether the prompt fits into the configured token limit.

        :param prompt: the prompt to check
        :return: True if the prompt has fewer than ``self.max_tokens`` tokens,
            False otherwise
        """
        if self.USE_MODEL:
            tokens = self.tokenize_with_model(prompt, self.MODEL_PATH)
        else:
            tokens = self.tokenize_with_tiktoken(prompt)
        # Keep the comparison strict: tokenize_with_model caps its output at
        # max_tokens, so an over-long prompt comes back with exactly that length.
        return len(tokens) < self.max_tokens

    def tokenize_with_model(self, prompt: str, model_path: str):
        """
        Generates a list of tokens for the prompt.
        Uses the tokenizer that transformers' ``AutoTokenizer`` loads from ``model_path``.
        Each tokenizer is loaded once and then reused for later calls.

        :param prompt: the prompt to generate tokens for
        :param model_path: the path to the model to use
        :return: the encoded prompt, cut off at ``self.max_tokens`` tokens
        """
        try:
            cache = self._tokenizers
        except AttributeError:
            # Instance was created without running __init__.
            cache = self._tokenizers = {}
        if model_path not in cache:
            cache[model_path] = AutoTokenizer.from_pretrained(model_path)
        return cache[model_path].encode(prompt, truncation=True, max_length=self.max_tokens)

    @staticmethod
    def tokenize_with_tiktoken(prompt: str):
        """
        Generates a list of tokens for the prompt using tiktoken with the cl100k_base encoding.
        Using this method can be useful when running the application with OpenAI models, as they use
        the same encoding.

        :param prompt: the prompt to generate tokens for
        :return: A list of integers representing the tokens
        """
        return tiktoken.get_encoding("cl100k_base").encode(prompt)

    def _prompt_or_empty(self, prompt: str) -> str:
        """
        Returns ``prompt`` if it passes ``check_token_limit``, otherwise an empty string.
        """
        return prompt if self.check_token_limit(prompt) else ""

    @staticmethod
    def _generate_prompts_with_different_size(method_name: str,
                                              class_name: str,
                                              method: dict):
        """
        Fills the system prompt followed by ``prompt_template_1`` for one method.

        :param method_name: value for the ``method_name`` placeholder
        :param class_name: value for the ``class_name`` placeholder
        :param method: dict whose "fullText" entry fills the ``method_code`` placeholder
        :return: the formatted prompt
        """
        prompt_template = "\n".join([system_prompt, prompt_template_1])
        return prompt_template.format(
            method_name=method_name,
            class_name=class_name,
            method_code=method["fullText"],
            testing_framework=PromptBuilder._TESTING_FRAMEWORK,
            mocking_framework=PromptBuilder._MOCKING_FRAMEWORK,
        )

    @staticmethod
    def _generate_prompts_for_class(class_name: str,
                                    class_: dict):
        """
        Fills the system prompt followed by ``prompt_template_2`` for one class.

        :param class_name: value for the ``class_name`` placeholder
        :param class_: dict whose "fullText" entry fills the ``class_code`` placeholder
        :return: the formatted prompt
        """
        prompt_template = "\n".join([system_prompt, prompt_template_2])
        return prompt_template.format(
            class_name=class_name,
            class_code=class_["fullText"],
            testing_framework=PromptBuilder._TESTING_FRAMEWORK,
            mocking_framework=PromptBuilder._MOCKING_FRAMEWORK,
        )

    @staticmethod
    def _generate_prompts_for_method_(method: dict,
                                      class_: str):
        """
        Fills the method prompt, taking the method name from the method dict itself.

        :param method: dict with "methodIdentifier" and "fullText" entries
        :param class_: the name of the class that owns the method
        :return: the formatted prompt
        """
        # Same template and placeholders as _generate_prompts_with_different_size.
        return PromptBuilder._generate_prompts_with_different_size(
            method["methodIdentifier"], class_, method)