Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 24 additions & 57 deletions 01_text_to_sql_pipeline_vLLM_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,6 @@ def extract_sql_query(self, response_object):
return value
return None

def handle_streaming_response(self, response_gen):
final_response = ""
for chunk in response_gen:
final_response += chunk
return final_response

def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[str, Generator, Iterator]:
# Use the established psycopg2 connection to create a SQLAlchemy engine
Expand Down Expand Up @@ -199,54 +194,22 @@ def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dic

Question: {query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""



synthesis_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI Assistant synthesizing the response from a PostgreSQL query.
Make sure to always use the stop token you were trained on at the end of a response: <|eot_id|>

You are required to use the following format, each taking one line:
<|start_header_id|>user<|end_header_id|>

SQLResponse:



******** edit this *********



Only use tables listed below.
movies

Only use columns listed below.
[('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)]

Question: How many rows in the database?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

SQLQuery: SELECT COUNT(*) FROM "movies"<|eot_id|>

<|start_header_id|>user<|end_header_id|>
Only use tables listed below.
movies

Only use columns listed below.
[('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)]

Question: How many comedy movies in 1995?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

SQLQuery: SELECT COUNT(*) FROM "movies" WHERE "Release Year" = 1995 AND "genre" = 'comedy'<|start_header_id|>user<|end_header_id|>

Only use tables listed below.
movies

Only use columns listed below.
[('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)]

Question: {query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""

response_synthesis_prompt_str = (
"<|begin_of_text|><|start_header_id|>system<|end_header_id|>"
"Given an input question, synthesize a response from the SQL Response results.\n"
"Also provide full received SQL query in your response.\n"
"Format SQL query with: ```sql <sql query>```"
"Answer in the language in which you were asked the question.\n"
"Question: {query_str}\n"
"SQL query: {sql_query}\n"
"SQL Response: {context_str}\n"
"Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>"
)

response_synthesis_prompt = PromptTemplate(
response_synthesis_prompt_str,
)

text_to_sql_template = PromptTemplate(text_to_sql_prompt)

Expand All @@ -257,35 +220,39 @@ def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dic
embed_model="local",
text_to_sql_prompt=text_to_sql_template,
synthesize_response=False,
#response_synthesis_prompt=synthesis_prompt,
#response_synthesis_prompt=response_synthesis_prompt,
streaming=True
)

try:
response = query_engine.query(user_message)
sql_query = self.extract_sql_query(response.metadata)

if hasattr(response, 'response_gen'):
final_response = self.handle_streaming_response(response.response_gen)
result = f"Generated SQL Query:\n```sql\n{sql_query}\n```\nResponse:\n{final_response}"
self.engine.dispose()
return result
return response.response_gen

else:
final_response = response.response
result = f"Generated SQL Query:\n```sql\n{sql_query}\n```\nResponse:\n{final_response}"
self.engine.dispose()
return result

except aiohttp.ClientResponseError as e:
logging.error(f"ClientResponseError: {e}")
self.engine.dispose()
return f"ClientResponseError: {e}"

except aiohttp.ClientPayloadError as e:
logging.error(f"ClientPayloadError: {e}")
self.engine.dispose()
return f"ClientPayloadError: {e}"

except aiohttp.ClientConnectionError as e:
logging.error(f"ClientConnectionError: {e}")
self.engine.dispose()
return f"ClientConnectionError: {e}"

except Exception as e:
logging.error(f"Unexpected error: {e}")
self.engine.dispose()
Expand Down