diff --git a/01_text_to_sql_pipeline_vLLM_llama.py b/01_text_to_sql_pipeline_vLLM_llama.py index 0d0d60f..ccb0a56 100644 --- a/01_text_to_sql_pipeline_vLLM_llama.py +++ b/01_text_to_sql_pipeline_vLLM_llama.py @@ -131,11 +131,6 @@ def extract_sql_query(self, response_object): return value return None - def handle_streaming_response(self, response_gen): - final_response = "" - for chunk in response_gen: - final_response += chunk - return final_response def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dict) -> Union[str, Generator, Iterator]: # Use the established psycopg2 connection to create a SQLAlchemy engine @@ -199,54 +194,22 @@ def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dic Question: {query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|> """ - - - - synthesis_prompt = """ - <|begin_of_text|><|start_header_id|>system<|end_header_id|> - You are a helpful AI Assistant synthesizing the response from a PostgreSQL query. - Make sure to always use the stop token you were trained on at the end of a response: <|eot_id|> - - You are required to use the following format, each taking one line: - <|start_header_id|>user<|end_header_id|> - - SQLResponse: - - - - ******** edit this ********* - - - - Only use tables listed below. - movies - - Only use columns listed below. - [('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)] - - Question: How many rows in the database?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - - SQLQuery: SELECT COUNT(*) FROM "movies"<|eot_id|> - - <|start_header_id|>user<|end_header_id|> - Only use tables listed below. - movies - - Only use columns listed below. - [('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)] - - Question: How many comedy movies in 1995?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - - SQLQuery: SELECT COUNT(*) FROM "movies" WHERE "Release Year" = 1995 AND "genre" = 'comedy'<|start_header_id|>user<|end_header_id|> - - Only use tables listed below. - movies - - Only use columns listed below. - [('Release Year',), ('title',), ('Origin/Ethnicity',), ('director',), ('Cast',), ('genre',), ('Wiki Page',), ('plot',)] - - Question: {query_str}<|eot_id|><|start_header_id|>assistant<|end_header_id|> - """ + + response_synthesis_prompt_str = ( + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>" + "Given an input question, synthesize a response from the SQL Response results.\n" + "Also provide full received SQL query in your response.\n" + "Format SQL query with: ```sql ```" + "Answer in the language in which you were asked the question.\n" + "Question: {query_str}\n" + "SQL query: {sql_query}\n" + "SQL Response: {context_str}\n" + "Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>" + ) + + response_synthesis_prompt = PromptTemplate( + response_synthesis_prompt_str, + ) text_to_sql_template = PromptTemplate(text_to_sql_prompt) @@ -257,35 +220,39 @@ def pipe(self, user_message: str, model_id: str, messages: List[dict], body: dic embed_model="local", text_to_sql_prompt=text_to_sql_template, synthesize_response=False, - #response_synthesis_prompt=synthesis_prompt, + #response_synthesis_prompt=response_synthesis_prompt, streaming=True ) try: response = query_engine.query(user_message) sql_query = self.extract_sql_query(response.metadata) + if hasattr(response, 'response_gen'): - final_response = self.handle_streaming_response(response.response_gen) - result = f"Generated SQL Query:\n```sql\n{sql_query}\n```\nResponse:\n{final_response}" self.engine.dispose() - return result + return response.response_gen + else: final_response = response.response result = f"Generated SQL Query:\n```sql\n{sql_query}\n```\nResponse:\n{final_response}" self.engine.dispose() return result + except aiohttp.ClientResponseError as e: logging.error(f"ClientResponseError: {e}") self.engine.dispose() return f"ClientResponseError: {e}" + except aiohttp.ClientPayloadError as e: logging.error(f"ClientPayloadError: {e}") self.engine.dispose() return f"ClientPayloadError: {e}" + except aiohttp.ClientConnectionError as e: logging.error(f"ClientConnectionError: {e}") self.engine.dispose() return f"ClientConnectionError: {e}" + except Exception as e: logging.error(f"Unexpected error: {e}") self.engine.dispose()