From 8c72d1dca6d7d31c4252e39ea18dde29741eca6e Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 10:07:33 +0300 Subject: [PATCH 01/58] remove old poetry version --- .github/workflows/pylint.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 1659747f..3689095f 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -17,8 +17,6 @@ jobs: - name: Install Poetry uses: abatilo/actions-poetry@v3 - with: - poetry-version: "1.8.3" - name: Configure Poetry run: | From e1c84bd0e4e02efbc3dbea82f68e469fc402e39d Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 10:14:18 +0300 Subject: [PATCH 02/58] fix lint --- api/agents.py | 144 ++++++++----- api/config.py | 47 ++--- api/constants.py | 105 +++++----- api/extensions.py | 6 +- api/graph.py | 150 ++++++++------ api/helpers/crm_data_generator.py | 333 ++++++++++++++++++------------ api/index.py | 151 +++++++++----- api/loaders/base_loader.py | 1 + api/loaders/csv_loader.py | 296 +++++++++++++++----------- api/loaders/graph_loader.py | 139 +++++++------ api/loaders/json_loader.py | 45 ++-- api/loaders/odata_loader.py | 101 ++++++--- api/loaders/schema_validator.py | 22 +- api/utils.py | 108 ++++++---- onthology.py | 2 +- 15 files changed, 980 insertions(+), 670 deletions(-) diff --git a/api/agents.py b/api/agents.py index 7082c0c8..1f5228dd 100644 --- a/api/agents.py +++ b/api/agents.py @@ -1,9 +1,12 @@ import json +from typing import Any, Dict, List + from litellm import completion + from api.config import Config -from typing import List, Dict, Any -class AnalysisAgent(): + +class AnalysisAgent: def __init__(self, queries_history: list, result_history: list): if result_history is None: self.messages = [] @@ -13,48 +16,63 @@ def __init__(self, queries_history: list, result_history: list): self.messages.append({"role": "user", "content": query}) self.messages.append({"role": "assistant", "content": result}) - def get_analysis(self, user_query: str, combined_tables: list, db_description: str, instructions: str = None) -> dict: + def get_analysis( + self, + user_query: str, + combined_tables: list, + db_description: str, + instructions: str = None, + ) -> dict: formatted_schema = self._format_schema(combined_tables) - prompt = self._build_prompt(user_query, formatted_schema, db_description, instructions) + prompt = self._build_prompt( + user_query, formatted_schema, db_description, instructions + ) self.messages.append({"role": "user", "content": prompt}) - completion_result = completion(model=Config.COMPLETION_MODEL, - messages=self.messages, - temperature=0, - top_p=1, - ) - + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0, + top_p=1, + ) + response = completion_result.choices[0].message.content analysis = _parse_response(response) - if isinstance(analysis['ambiguities'], list): - analysis['ambiguities'] = [item.replace('-', ' ') for item in analysis['ambiguities']] - analysis['ambiguities'] = "- " + "- ".join(analysis['ambiguities']) - if isinstance(analysis['missing_information'], list): - analysis['missing_information'] = [item.replace('-', ' ') for item in analysis['missing_information']] - analysis['missing_information'] = "- " + "- ".join(analysis['missing_information']) - self.messages.append({"role": "assistant", "content": analysis['sql_query']}) + if isinstance(analysis["ambiguities"], list): + analysis["ambiguities"] = [ + item.replace("-", " ") for item in analysis["ambiguities"] + ] + analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) + if isinstance(analysis["missing_information"], list): + analysis["missing_information"] = [ + item.replace("-", " ") for item in analysis["missing_information"] + ] + analysis["missing_information"] = "- " + "- ".join( + analysis["missing_information"] + ) + self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) return analysis - + def _format_schema(self, schema_data: List) -> str: """ Format the schema data into a readable format for the prompt. - + Args: schema_data: Schema in the structure [...] - + Returns: Formatted schema as a string """ formatted_schema = [] - + for table_info in schema_data: table_name = table_info[0] table_description = table_info[1] foreign_keys = table_info[2] columns = table_info[3] - + # Format table header table_str = f"Table: {table_name} - {table_description}\n" - + # Format columns using the updated OrderedDict structure for column in columns: col_name = column.get("columnName", "") @@ -62,11 +80,15 @@ def _format_schema(self, schema_data: List) -> str: col_description = column.get("description", "") col_key = column.get("keyType", None) nullable = column.get("nullable", False) - - key_info = f", PRIMARY KEY" if col_key == "PRI" else f", FOREIGN KEY" if col_key == "FK" else "" + + key_info = ( + f", PRIMARY KEY" + if col_key == "PRI" + else f", FOREIGN KEY" if col_key == "FK" else "" + ) column_str = f" - {col_name} ({col_type},{key_info},{col_key},{nullable}): {col_description}" table_str += column_str + "\n" - + # Format foreign keys if isinstance(foreign_keys, dict) and foreign_keys: table_str += " Foreign Keys:\n" @@ -74,20 +96,24 @@ def _format_schema(self, schema_data: List) -> str: column = fk_info.get("column", "") ref_table = fk_info.get("referenced_table", "") ref_column = fk_info.get("referenced_column", "") - table_str += f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" - + table_str += ( + f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" + ) + formatted_schema.append(table_str) - + return "\n".join(formatted_schema) - def _build_prompt(self, user_input: str, formatted_schema: str, db_description: str, instructions) -> str: + def _build_prompt( + self, user_input: str, formatted_schema: str, db_description: str, instructions + ) -> str: """ Build the prompt for Claude to analyze the query. - + Args: user_input: The natural language query from the user formatted_schema: Formatted database schema - + Returns: The formatted prompt for Claude """ @@ -169,7 +195,7 @@ def _build_prompt(self, user_input: str, formatted_schema: str, db_description: return prompt -class RelevancyAgent(): +class RelevancyAgent: def __init__(self, queries_history: list, result_history: list): if result_history is None: self.messages = [] @@ -180,13 +206,21 @@ def __init__(self, queries_history: list, result_history: list): self.messages.append({"role": "assistant", "content": result}) def get_answer(self, user_question: str, database_desc: dict) -> dict: - self.messages.append({"role": "user", "content": RELEVANCY_PROMPT.format(QUESTION_PLACEHOLDER=user_question, DB_PLACEHOLDER=json.dumps(database_desc))}) + self.messages.append( + { + "role": "user", + "content": RELEVANCY_PROMPT.format( + QUESTION_PLACEHOLDER=user_question, + DB_PLACEHOLDER=json.dumps(database_desc), + ), + } + ) completion_result = completion( model=Config.COMPLETION_MODEL, messages=self.messages, temperature=0, ) - + answer = completion_result.choices[0].message.content self.messages.append({"role": "assistant", "content": answer}) return _parse_response(answer) @@ -240,28 +274,33 @@ def get_answer(self, user_question: str, database_desc: dict) -> dict: """ -class FollowUpAgent(): +class FollowUpAgent: def __init__(self): pass - def get_answer(self, user_question: str, conversation_hist: list, database_schema: dict) -> dict: + def get_answer( + self, user_question: str, conversation_hist: list, database_schema: dict + ) -> dict: completion_result = completion( model=Config.COMPLETION_MODEL, messages=[ { - "content": FOLLOW_UP_PROMPT.format(QUESTION=user_question, HISTORY=conversation_hist, SCHEMA=json.dumps(database_schema)), - "role": "user" + "content": FOLLOW_UP_PROMPT.format( + QUESTION=user_question, + HISTORY=conversation_hist, + SCHEMA=json.dumps(database_schema), + ), + "role": "user", } ], response_format={"type": "json_object"}, temperature=0, ) - + answer = completion_result.choices[0].message.content return json.loads(answer) - FOLLOW_UP_PROMPT = """You are an expert assistant that receives two inputs: 1. The user’s question: {QUESTION} @@ -297,8 +336,7 @@ def get_answer(self, user_question: str, conversation_hist: list, database_schem 4. Ensure your response is concise, polite, and helpful. When asking clarifying questions, be specific and guide the user toward providing the missing details so you can effectively address their query.""" - -class TaxonomyAgent(): +class TaxonomyAgent: def __init__(self): pass @@ -306,7 +344,7 @@ def get_answer(self, question: str, sql: str) -> str: messages = [ { "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql), - "role": "user" + "role": "user", } ] completion_result = completion( @@ -314,12 +352,11 @@ def get_answer(self, question: str, sql: str) -> str: messages=messages, temperature=0, ) - + answer = completion_result.choices[0].message.content return answer - TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query provde a single clarification question to the user. * For any SQL query that contain WHERE clause, provide a clarification question to the user about the generated value. * Your question can contain more than one clarification related to WHERE clause. @@ -347,22 +384,23 @@ def get_answer(self, question: str, sql: str) -> str: The question to the user:" """ + def _parse_response(response: str) -> Dict[str, Any]: """ Parse Claude's response to extract the analysis. - + Args: response: Claude's response string - + Returns: Parsed analysis results """ try: # Extract JSON from the response - json_start = response.find('{') - json_end = response.rfind('}') + 1 + json_start = response.find("{") + json_end = response.rfind("}") + 1 json_str = response[json_start:json_end] - + # Parse the JSON analysis = json.loads(json_str) return analysis @@ -372,5 +410,5 @@ def _parse_response(response: str) -> Dict[str, Any]: "is_sql_translatable": False, "confidence": 0, "explanation": f"Failed to parse response: {str(e)}", - "error": str(response) - } \ No newline at end of file + "error": str(response), + } diff --git a/api/config.py b/api/config.py index 2642ea9d..6accfb4d 100644 --- a/api/config.py +++ b/api/config.py @@ -1,56 +1,55 @@ -""" +""" This module contains the configuration for the text2sql module. """ + +import dataclasses import os from typing import Union -import dataclasses -from litellm import embedding + import boto3 +from litellm import embedding -class EmbeddingsModel(): - - def __init__( - self, - model_name: str, - config: dict = None - ): +class EmbeddingsModel: + + def __init__(self, model_name: str, config: dict = None): self.model_name = model_name self.config = config - + def embed(self, text: Union[str, list]) -> list: """ Get the embeddings of the text - + Args: text (str|list): The text(s) to embed - + Returns: list: The embeddings of the text - + """ embeddings = embedding(model=self.model_name, input=text) embeddings = [embedding["embedding"] for embedding in embeddings.data] return embeddings - + def get_vector_size(self) -> int: """ Get the size of the vector - + Returns: int: The size of the vector - + """ - response = embedding(input = ["Hello World"], model=self.model_name) - size = len(response.data[0]['embedding']) + response = embedding(input=["Hello World"], model=self.model_name) + size = len(response.data[0]["embedding"]) return size @dataclasses.dataclass class Config: """ - Configuration class for the text2sql module. + Configuration class for the text2sql module. """ + SCHEMA_PATH = "api/schema_schema.json" EMBEDDING_MODEL_NAME = "azure/text-embedding-ada-002" COMPLETION_MODEL = "azure/gpt-4.1" @@ -66,11 +65,7 @@ class Config: # config["aws_region_name"] = AWS_REGION # config["aws_profile_name"] = AWS_PROFILE - EMBEDDING_MODEL = EmbeddingsModel( - model_name=EMBEDDING_MODEL_NAME, - config=config - ) - + EMBEDDING_MODEL = EmbeddingsModel(model_name=EMBEDDING_MODEL_NAME, config=config) FIND_SYSTEM_PROMPT = """ You are an expert in analyzing natural language queries into SQL tables descriptions. @@ -139,4 +134,4 @@ class Config: * **User Query (Natural Language):** You will be given a user's current question or request in natural language. - """ \ No newline at end of file + """ diff --git a/api/constants.py b/api/constants.py index a1eacddf..35357011 100644 --- a/api/constants.py +++ b/api/constants.py @@ -1,86 +1,85 @@ -EXAMPLES = {'crm_usecase': ["Which companies have generated the most revenue through closed deals, and how much revenue did they generate?", - "How many leads converted into deals over the last month", - "Which companies have open sales opportunities and active SLA agreements in place?", - "Which high-value sales opportunities (value > $50,000) have upcoming meetings scheduled, and what companies are they associated with?"], - 'ERP_system': [ - # "What is the total value of all purchase orders created in the last quarter?", - # "Which suppliers have the highest number of active purchase orders, and what is the total value of those orders?", - "What is the total order value for customer Almo Office?", - "Show the total amount of all orders placed on 11/24", - "What's the profit for order SO2400002?", - "List all confirmed orders form today with their final prices", - "How many items are in order SO2400002?", - - # Product-Specific Questions - "What is the price of Office Chair (part 0001100)?", - "List all items with quantity greater than 3 units", - "Show me all products with price above $20", - "What's the total cost of all A4 Paper items ordered?", - "Which items have the highest profit margin?", - - # Financial Analysis Questions - "Calculate the total profit for this year", - "Show me orders with overall discount greater than 5%", - "What's the average profit percentage across all items?", - "List orders with final price exceeding $700", - "Show me items with profit margin above 50%", - - # Customer-Related Questions - "How many orders has customer 100038 placed?", - "What's the total purchase amount by Almo Office?", - "List all orders with their customer names and contact details", - "Show me customers with orders above $500", - "What's the average order value per customer?", - - # Inventory/Stock Questions - "Which items have zero quantity?", - "Show me all items with their crate types", - "List products with their packaging details", - "What's the total quantity ordered for each product?", - "Show me items with pending shipments" - ] - } +EXAMPLES = { + "crm_usecase": [ + "Which companies have generated the most revenue through closed deals, and how much revenue did they generate?", + "How many leads converted into deals over the last month", + "Which companies have open sales opportunities and active SLA agreements in place?", + "Which high-value sales opportunities (value > $50,000) have upcoming meetings scheduled, and what companies are they associated with?", + ], + "ERP_system": [ + # "What is the total value of all purchase orders created in the last quarter?", + # "Which suppliers have the highest number of active purchase orders, and what is the total value of those orders?", + "What is the total order value for customer Almo Office?", + "Show the total amount of all orders placed on 11/24", + "What's the profit for order SO2400002?", + "List all confirmed orders form today with their final prices", + "How many items are in order SO2400002?", + # Product-Specific Questions + "What is the price of Office Chair (part 0001100)?", + "List all items with quantity greater than 3 units", + "Show me all products with price above $20", + "What's the total cost of all A4 Paper items ordered?", + "Which items have the highest profit margin?", + # Financial Analysis Questions + "Calculate the total profit for this year", + "Show me orders with overall discount greater than 5%", + "What's the average profit percentage across all items?", + "List orders with final price exceeding $700", + "Show me items with profit margin above 50%", + # Customer-Related Questions + "How many orders has customer 100038 placed?", + "What's the total purchase amount by Almo Office?", + "List all orders with their customer names and contact details", + "Show me customers with orders above $500", + "What's the average order value per customer?", + # Inventory/Stock Questions + "Which items have zero quantity?", + "Show me all items with their crate types", + "List products with their packaging details", + "What's the total quantity ordered for each product?", + "Show me items with pending shipments", + ], +} BENCHMARK = [ { "question": "List all contacts who are associated with companies that have at least one active deal in the pipeline, and include the deal stage.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, d.deal_name, ds.stage_name FROM contacts AS c JOIN company_contacts AS cc ON c.contact_id = cc.contact_id JOIN companies AS co ON cc.company_id = co.company_id JOIN deals AS d ON co.company_id = d.company_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.is_active = 1;" + "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, d.deal_name, ds.stage_name FROM contacts AS c JOIN company_contacts AS cc ON c.contact_id = cc.contact_id JOIN companies AS co ON cc.company_id = co.company_id JOIN deals AS d ON co.company_id = d.company_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.is_active = 1;", }, { "question": "Which sales representatives (users) have closed deals worth more than $100,000 in the past year, and what was the total value of deals they closed?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS total_closed_value FROM users AS u JOIN deals AS d ON u.user_id = d.owner_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' AND d.close_date >= DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id HAVING total_closed_value > 100000;" + "sql": "SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS total_closed_value FROM users AS u JOIN deals AS d ON u.user_id = d.owner_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' AND d.close_date >= DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id HAVING total_closed_value > 100000;", }, { "question": "Find all contacts who attended at least one event and were later converted into leads that became opportunities within three months of the event.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN event_attendees AS ea ON c.contact_id = ea.contact_id JOIN events AS e ON ea.event_id = e.event_id JOIN leads AS l ON c.contact_id = l.contact_id JOIN opportunities AS o ON l.lead_id = o.lead_id WHERE o.created_date BETWEEN e.event_date AND DATE_ADD(e.event_date, INTERVAL 3 MONTH);" + "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN event_attendees AS ea ON c.contact_id = ea.contact_id JOIN events AS e ON ea.event_id = e.event_id JOIN leads AS l ON c.contact_id = l.contact_id JOIN opportunities AS o ON l.lead_id = o.lead_id WHERE o.created_date BETWEEN e.event_date AND DATE_ADD(e.event_date, INTERVAL 3 MONTH);", }, { "question": "Which customers have the highest lifetime value based on their total invoice payments, including refunds and discounts?", - "sql": "SELECT c.contact_id, c.first_name, c.last_name, SUM(i.total_amount - COALESCE(r.refund_amount, 0) - COALESCE(d.discount_amount, 0)) AS lifetime_value FROM contacts AS c JOIN orders AS o ON c.contact_id = o.contact_id JOIN invoices AS i ON o.order_id = i.order_id LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;" + "sql": "SELECT c.contact_id, c.first_name, c.last_name, SUM(i.total_amount - COALESCE(r.refund_amount, 0) - COALESCE(d.discount_amount, 0)) AS lifetime_value FROM contacts AS c JOIN orders AS o ON c.contact_id = o.contact_id JOIN invoices AS i ON o.order_id = i.order_id LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;", }, { "question": "Show all deals that have involved at least one email exchange, one meeting, and one phone call with a contact in the past six months.", - "sql": "SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d JOIN contacts AS c ON d.contact_id = c.contact_id JOIN emails AS e ON c.contact_id = e.contact_id JOIN meetings AS m ON c.contact_id = m.contact_id JOIN phone_calls AS p ON c.contact_id = p.contact_id WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);" + "sql": "SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d JOIN contacts AS c ON d.contact_id = c.contact_id JOIN emails AS e ON c.contact_id = e.contact_id JOIN meetings AS m ON c.contact_id = m.contact_id JOIN phone_calls AS p ON c.contact_id = p.contact_id WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);", }, { "question": "Which companies have the highest number of active support tickets, and how does their number of tickets correlate with their total deal value?", - "sql": "SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, SUM(d.amount) AS total_deal_value FROM companies AS co LEFT JOIN support_tickets AS st ON co.company_id = st.company_id AND st.status = 'Open' LEFT JOIN deals AS d ON co.company_id = d.company_id GROUP BY co.company_id ORDER BY active_tickets DESC;" + "sql": "SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, SUM(d.amount) AS total_deal_value FROM companies AS co LEFT JOIN support_tickets AS st ON co.company_id = st.company_id AND st.status = 'Open' LEFT JOIN deals AS d ON co.company_id = d.company_id GROUP BY co.company_id ORDER BY active_tickets DESC;", }, { "question": "Retrieve all contacts who are assigned to a sales rep but have not been contacted via email, phone, or meeting in the past three months.", - "sql": "SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN users AS u ON c.owner_id = u.user_id LEFT JOIN emails AS e ON c.contact_id = e.contact_id AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN meetings AS m ON c.contact_id = m.contact_id AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) WHERE e.contact_id IS NULL AND p.contact_id IS NULL AND m.contact_id IS NULL;" + "sql": "SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN users AS u ON c.owner_id = u.user_id LEFT JOIN emails AS e ON c.contact_id = e.contact_id AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN meetings AS m ON c.contact_id = m.contact_id AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) WHERE e.contact_id IS NULL AND p.contact_id IS NULL AND m.contact_id IS NULL;", }, { "question": "Which email campaigns resulted in the highest number of closed deals, and what was the average deal size for those campaigns?", - "sql": "SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec JOIN contacts AS c ON ec.campaign_id = c.campaign_id JOIN deals AS d ON c.contact_id = d.contact_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id ORDER BY closed_deals DESC;" + "sql": "SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec JOIN contacts AS c ON ec.campaign_id = c.campaign_id JOIN deals AS d ON c.contact_id = d.contact_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id ORDER BY closed_deals DESC;", }, { "question": "Find the average time it takes for a lead to go from creation to conversion into a deal, broken down by industry.", - "sql": "SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) AS avg_conversion_time FROM leads AS l JOIN companies AS co ON l.company_id = co.company_id JOIN industries AS ind ON co.industry_id = ind.industry_id JOIN opportunities AS o ON l.lead_id = o.lead_id JOIN deals AS d ON o.opportunity_id = d.opportunity_id WHERE d.stage_id IN (SELECT stage_id FROM deal_stages WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name ORDER BY avg_conversion_time ASC;" + "sql": "SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) AS avg_conversion_time FROM leads AS l JOIN companies AS co ON l.company_id = co.company_id JOIN industries AS ind ON co.industry_id = ind.industry_id JOIN opportunities AS o ON l.lead_id = o.lead_id JOIN deals AS d ON o.opportunity_id = d.opportunity_id WHERE d.stage_id IN (SELECT stage_id FROM deal_stages WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name ORDER BY avg_conversion_time ASC;", }, { "question": "Which sales reps (users) have the highest win rate, calculated as the percentage of their assigned leads that convert into closed deals?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate FROM users AS u JOIN leads AS l ON u.user_id = l.owner_id LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id ORDER BY win_rate DESC;" - } + "sql": "SELECT u.user_id, u.first_name, u.last_name, COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate FROM users AS u JOIN leads AS l ON u.user_id = l.owner_id LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id ORDER BY win_rate DESC;", + }, ] diff --git a/api/extensions.py b/api/extensions.py index 0c40559d..42602e8f 100644 --- a/api/extensions.py +++ b/api/extensions.py @@ -1,12 +1,14 @@ -""" Extensions for the text2sql library """ +"""Extensions for the text2sql library""" + import os + from falkordb import FalkorDB # Connect to FalkorDB url = os.getenv("FALKORDB_URL", None) if url is None: try: - db = FalkorDB(host='localhost', port=6379) + db = FalkorDB(host="localhost", port=6379) except Exception as e: raise Exception(f"Failed to connect to FalkorDB: {e}") else: diff --git a/api/graph.py b/api/graph.py index 6632caef..c8938a1d 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,76 +1,95 @@ -""" Module to handle the graph data loading into the database. """ +"""Module to handle the graph data loading into the database.""" + import json import logging +from itertools import combinations from typing import List, Tuple + from litellm import completion from pydantic import BaseModel + from api.config import Config from api.extensions import db -from itertools import combinations -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) + class TableDescription(BaseModel): - """ Table Description """ + """Table Description""" + name: str description: str + class ColumnDescription(BaseModel): - """ Column Description """ + """Column Description""" + name: str description: str + class Descriptions(BaseModel): - """ List of tables """ + """List of tables""" + tables_descriptions: list[TableDescription] columns_descriptions: list[ColumnDescription] + def get_db_description(graph_id: str) -> str: - """ Get the database description from the graph. """ + """Get the database description from the graph.""" graph = db.select_graph(graph_id) - query_result = graph.query(""" + query_result = graph.query( + """ MATCH (d:Database) RETURN d.description """ ) - + if not query_result.result_set: return "No description available for this database." - + return query_result.result_set[0][0] # Return the first result's description + def find( - graph_id: str, - queries_history: List[str], - db_description: str = None + graph_id: str, queries_history: List[str], db_description: str = None ) -> Tuple[bool, List[dict]]: - """ Find the tables and columns relevant to the user's query. """ - + """Find the tables and columns relevant to the user's query.""" + graph = db.select_graph(graph_id) user_query = queries_history[-1] previous_queries = queries_history[:-1] - logging.info(f"Calling to an LLM to find relevant tables and columns for the query: {user_query}") + logging.info( + f"Calling to an LLM to find relevant tables and columns for the query: {user_query}" + ) # Call the completion model to get the relevant Cypher queries to retrieve # from the Graph that represent the Database schema. # The completion model will generate a set of Cypher query to retrieve the relevant nodes. - completion_result = completion(model=Config.COMPLETION_MODEL, - response_format=Descriptions, - messages=[ - { - "content": Config.FIND_SYSTEM_PROMPT.format(db_description=db_description), - "role": "system" - }, - { - "content": json.dumps({ - "previous_user_queries:": previous_queries, - "user_query": user_query - }), - "role": "user" - } - ], - temperature=0, - ) + completion_result = completion( + model=Config.COMPLETION_MODEL, + response_format=Descriptions, + messages=[ + { + "content": Config.FIND_SYSTEM_PROMPT.format( + db_description=db_description + ), + "role": "system", + }, + { + "content": json.dumps( + { + "previous_user_queries:": previous_queries, + "user_query": user_query, + } + ), + "role": "user", + }, + ], + temperature=0, + ) json_str = completion_result.choices[0].message.content @@ -80,7 +99,9 @@ def find( logging.info(f"Find tables based on: {descriptions.tables_descriptions}") tables_des = _find_tables(graph, descriptions.tables_descriptions) logging.info(f"Find tables based on columns: {descriptions.columns_descriptions}") - tables_by_columns_des = _find_tables_by_columns(graph, descriptions.columns_descriptions) + tables_by_columns_des = _find_tables_by_columns( + graph, descriptions.columns_descriptions + ) # table names for sphere and route extraction base_tables_names = [table[0] for table in tables_des] @@ -88,9 +109,16 @@ def find( tables_by_sphere = _find_tables_sphere(graph, base_tables_names) logging.info(f"Extracting tables by connecting routes {base_tables_names}") tables_by_route, _ = find_connecting_tables(graph, base_tables_names) - combined_tables = _get_unique_tables(tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere) - - return True, combined_tables, [tables_des, tables_by_columns_des, tables_by_route, tables_by_sphere] + combined_tables = _get_unique_tables( + tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere + ) + + return ( + True, + combined_tables, + [tables_des, tables_by_columns_des, tables_by_route, tables_by_sphere], + ) + def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: @@ -99,7 +127,8 @@ def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: # Get the table node from the graph embedding_result = Config.EMBEDDING_MODEL.embed(table.description) - query_result = graph.query(""" + query_result = graph.query( + """ CALL db.idx.vector.queryNodes( 'Table', 'embedding', @@ -115,20 +144,21 @@ def _find_tables(graph, descriptions: List[TableDescription]) -> List[dict]: nullable: columns.nullable }) """, - { - 'embedding': embedding_result[0] - }) + {"embedding": embedding_result[0]}, + ) for node in query_result.result_set: if node not in result: result.append(node) - + return result + def _find_tables_sphere(graph, tables: List[str]) -> List[dict]: result = [] for table_name in tables: - query_result = graph.query(""" + query_result = graph.query( + """ MATCH (node:Table {name: $name}) MATCH (node)-[:BELONGS_TO]-(column)-[:REFERENCES]-()-[:BELONGS_TO]-(table_ref) WITH table_ref @@ -141,9 +171,8 @@ def _find_tables_sphere(graph, tables: List[str]) -> List[dict]: nullable: columns.nullable }) """, - { - 'name': table_name - }) + {"name": table_name}, + ) for node in query_result.result_set: if node not in result: result.append(node) @@ -158,7 +187,8 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis # Get the table node from the graph embedding_result = Config.EMBEDDING_MODEL.embed(column.description) - query_result = graph.query(""" + query_result = graph.query( + """ CALL db.idx.vector.queryNodes( 'Column', 'embedding', @@ -178,9 +208,8 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis nullable: columns.nullable }) """, - { - 'embedding': embedding_result[0] - }) + {"embedding": embedding_result[0]}, + ) for node in query_result.result_set: if node not in result: @@ -188,35 +217,38 @@ def _find_tables_by_columns(graph, descriptions: List[ColumnDescription]) -> Lis return result + def _get_unique_tables(tables_list): # Dictionary to store unique tables with the table name as the key unique_tables = {} - + for table_info in tables_list: table_name = table_info[0] # The first element is the table name - + # Only add if this table name hasn't been seen before try: if table_name not in unique_tables: table_info[3] = [dict(od) for od in table_info[3]] - table_info[2] = 'Foreign keys: ' + table_info[2] + table_info[2] = "Foreign keys: " + table_info[2] unique_tables[table_name] = table_info except: print(f"Error: {table_info}") - + # Return the values (the unique table info lists) return list(unique_tables.values()) -def find_connecting_tables(graph, table_names: List[str]) -> Tuple[List[dict], List[str]]: +def find_connecting_tables( + graph, table_names: List[str] +) -> Tuple[List[dict], List[str]]: """ Find all tables that form connections between any pair of tables in the input list. Handles both Table nodes and Column nodes with primary keys. - + Args: graph: The FalkorDB graph database connection table_names: List of table names to check connections between - + Returns: A set of all table names that form connections between any pair in the input """ @@ -259,5 +291,5 @@ def find_connecting_tables(graph, table_names: List[str]) -> Tuple[List[dict], L target_table.foreign_keys AS foreign_keys, columns """ - result = graph.query(query, {'pairs': pair_params}, timeout=300).result_set - return result, None \ No newline at end of file + result = graph.query(query, {"pairs": pair_params}, timeout=300).result_set + return result, None diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index 289e97e8..ce4818b2 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -1,9 +1,12 @@ import json import os import time +from typing import Any, Dict, List, Optional, Set, Tuple + import requests -from typing import Dict, List, Any, Optional, Set, Tuple -from litellm import completion, validate_environment, utils as litellm_utils +from litellm import completion +from litellm import utils as litellm_utils +from litellm import validate_environment OUTPUT_FILE = "complete_crm_schema.json" MAX_RETRIES = 3 @@ -14,13 +17,14 @@ "primary_keys": {}, # table_name -> primary_key_column "foreign_keys": {}, # table_name -> {column_name -> (referenced_table, referenced_column)} "processed_tables": set(), # Set of tables that have been processed - "table_relationships": {} # table_name -> set of related tables + "table_relationships": {}, # table_name -> set of related tables } + def load_initial_schema(file_path: str) -> Dict[str, Any]: """Load the initial schema file with table names""" try: - with open(file_path, 'r') as file: + with open(file_path, "r") as file: schema = json.load(file) print(f"Loaded initial schema with {len(schema.get('tables', {}))} tables") return schema @@ -28,140 +32,157 @@ def load_initial_schema(file_path: str) -> Dict[str, Any]: print(f"Error loading schema file: {e}") return {"database": "crm_system", "tables": {}} + def save_schema(schema: Dict[str, Any], output_file: str = OUTPUT_FILE) -> None: """Save the current schema to a file with metadata""" # Add metadata if "metadata" not in schema: schema["metadata"] = {} - + schema["metadata"]["last_updated"] = time.strftime("%Y-%m-%d %H:%M:%S") schema["metadata"]["completed_tables"] = len(key_registry["processed_tables"]) schema["metadata"]["total_tables"] = len(schema.get("tables", {})) schema["metadata"]["key_registry"] = { "primary_keys": key_registry["primary_keys"], "foreign_keys": key_registry["foreign_keys"], - "table_relationships": {k: list(v) for k, v in key_registry["table_relationships"].items()} + "table_relationships": { + k: list(v) for k, v in key_registry["table_relationships"].items() + }, } - - with open(output_file, 'w') as file: + + with open(output_file, "w") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {output_file}") + def update_key_registry(table_name: str, table_data: Dict[str, Any]) -> None: """Update the key registry with information from a processed table""" # Mark table as processed key_registry["processed_tables"].add(table_name) - + # Track primary keys if "columns" in table_data: for col_name, col_data in table_data["columns"].items(): if col_data.get("key") == "PRI": key_registry["primary_keys"][table_name] = col_name break - + # Track foreign keys and relationships if "foreign_keys" in table_data: if table_name not in key_registry["foreign_keys"]: key_registry["foreign_keys"][table_name] = {} - + if table_name not in key_registry["table_relationships"]: key_registry["table_relationships"][table_name] = set() - + for fk_name, fk_data in table_data["foreign_keys"].items(): column = fk_data.get("column") ref_table = fk_data.get("referenced_table") ref_column = fk_data.get("referenced_column") - + if column and ref_table and ref_column: - key_registry["foreign_keys"][table_name][column] = (ref_table, ref_column) - + key_registry["foreign_keys"][table_name][column] = ( + ref_table, + ref_column, + ) + # Update relationships key_registry["table_relationships"][table_name].add(ref_table) - + # Ensure the referenced table has an entry if ref_table not in key_registry["table_relationships"]: key_registry["table_relationships"][ref_table] = set() - + # Add the reverse relationship key_registry["table_relationships"][ref_table].add(table_name) + def find_related_tables(table_name: str, all_tables: List[str]) -> List[str]: """Find tables that might be related to the current table""" related = [] - + # Check registry first for already established relationships if table_name in key_registry["table_relationships"]: related.extend(key_registry["table_relationships"][table_name]) - + # Extract base name - base_parts = table_name.split('_') - + base_parts = table_name.split("_") + for other_table in all_tables: if other_table == table_name or other_table in related: continue - + # Direct naming relationship if table_name in other_table or other_table in table_name: related.append(other_table) continue - + # Check for common roots - other_parts = other_table.split('_') + other_parts = other_table.split("_") for part in base_parts: if part in other_parts and len(part) > 3: # Avoid short common words related.append(other_table) break - + return list(set(related)) # Remove duplicates -def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology) -> str: + +def get_table_prompt( + table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology +) -> str: """Generate a prompt for the LLM to create a table schema with proper relationships""" existing_tables = schema.get("tables", {}) - + # Find related tables related_tables = find_related_tables(table_name, all_table_names) - related_tables_str = ", ".join(related_tables) if related_tables else "None identified yet" - + related_tables_str = ( + ", ".join(related_tables) if related_tables else "None identified yet" + ) + # Suggest primary key pattern - table_base = table_name.split('_')[0] if '_' in table_name else table_name + table_base = table_name.split("_")[0] if "_" in table_name else table_name suggested_pk = f"{table_name}_id" # Default pattern - + # Check if related tables have primary keys to follow same pattern for related in related_tables: if related in key_registry["primary_keys"]: related_pk = key_registry["primary_keys"][related] - if related_pk.endswith('_id') and related in related_pk: + if related_pk.endswith("_id") and related in related_pk: # Follow the same pattern suggested_pk = f"{table_name}_id" break - + # Prepare foreign key suggestions fk_suggestions = [] for related in related_tables: if related in key_registry["primary_keys"]: - fk_suggestions.append({ - "column": f"{related}_id", - "referenced_table": related, - "referenced_column": key_registry["primary_keys"][related] - }) - + fk_suggestions.append( + { + "column": f"{related}_id", + "referenced_table": related, + "referenced_column": key_registry["primary_keys"][related], + } + ) + fk_suggestions_str = "" if fk_suggestions: fk_suggestions_str = "Consider these foreign key relationships:\n" for i, fk in enumerate(fk_suggestions[:5]): # Limit to 5 suggestions fk_suggestions_str += f"{i+1}. {fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}\n" - + # Include examples of related tables that have been processed related_examples = "" example_count = 0 for related in related_tables: - if (related in existing_tables and - isinstance(existing_tables[related], dict) and - 'columns' in existing_tables[related] and - example_count < 2): + if ( + related in existing_tables + and isinstance(existing_tables[related], dict) + and "columns" in existing_tables[related] + and example_count < 2 + ): related_examples += f"\nRelated table example:\n```json\n{json.dumps({related: existing_tables[related]}, indent=2)}\n```\n" example_count += 1 - + # Use contacts table as primary example if no related examples found contacts_example = """ { @@ -264,7 +285,7 @@ def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: L """ # Create context about the table's purpose table_context = get_table_context(table_name, related_tables) - keys = json.dumps(topology['tables'][table_name]) + keys = json.dumps(topology["tables"][table_name]) prompt = f""" You are an expert database architect specializing in CRM systems. Create a detailed JSON schema for the '{table_name}' table in our CRM database. @@ -326,11 +347,12 @@ def get_table_prompt(table_name: str, schema: Dict[str, Any], all_table_names: L """ return prompt + def get_table_context(table_name: str, related_tables: List[str]) -> str: """Generate contextual information about a table based on its name and related tables""" # Extract words from table name - words = table_name.replace('_', ' ').split() - + words = table_name.replace("_", " ").split() + # Common CRM entities entities = { "contact": "Contains information about individuals", @@ -349,9 +371,9 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: "order": "Contains information about customer orders", "subscription": "Contains information about recurring subscriptions", "ticket": "Contains information about support tickets", - "campaign": "Contains information about marketing campaigns" + "campaign": "Contains information about marketing campaigns", } - + # Common relationship patterns relationship_patterns = { "tags": "This is a tagging or categorization table that likely links to various entities", @@ -369,17 +391,19 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: "attachments": "This contains file attachments", "performance": "This tracks performance metrics", "feedback": "This contains feedback information", - "settings": "This contains configuration settings" + "settings": "This contains configuration settings", } - + context = f"The '{table_name}' table appears to be " - + # Check if this is a junction/linking table - if "_" in table_name and not any(p in table_name for p in relationship_patterns.keys()): + if "_" in table_name and not any( + p in table_name for p in relationship_patterns.keys() + ): parts = table_name.split("_") if len(parts) == 2 and all(len(p) > 2 for p in parts): return f"This appears to be a junction table linking '{parts[0]}' and '{parts[1]}', likely with a many-to-many relationship." - + # Check for main entities for entity, description in entities.items(): if entity in words: @@ -387,56 +411,65 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: break else: context += "part of the CRM system. " - + # Check for relationship patterns for pattern, description in relationship_patterns.items(): if pattern in table_name: context += f"{description}. " break - + # Add related tables info if related_tables: context += f"It appears to be related to the following tables: {', '.join(related_tables)}. " - + # Guess if it's a child table for related in related_tables: if related in table_name and len(related) < len(table_name): - context += f"It may be a child or detail table for the {related} table. " + context += ( + f"It may be a child or detail table for the {related} table. " + ) break - + return context + def call_llm_api(prompt: str, retries: int = MAX_RETRIES) -> Optional[str]: """Call the LLM API with the given prompt, with retry logic""" for attempt in range(1, retries + 1): try: config = {} - config['temperature'] = 0.5 - config['response_format'] = { "type": "json_object" } - - + config["temperature"] = 0.5 + config["response_format"] = {"type": "json_object"} + response = completion( model="gemini/gemini-2.0-flash", messages=[{"role": "user", "content": prompt}], - **config + **config, + ) + result = ( + response.json() + .get("choices", [{}])[0] + .get("message", "") + .get("content", "") + .strip() ) - result = response.json().get("choices", [{}])[0].get("message", "").get("content", "").strip() if result: return result else: print(f"Empty response from API (attempt {attempt}/{retries})") - + except requests.exceptions.RequestException as e: print(f"API request error (attempt {attempt}/{retries}): {e}") - + if attempt < retries: sleep_time = RETRY_DELAY * attempt print(f"Retrying in {sleep_time} seconds...") time.sleep(sleep_time) - + print("All retry attempts failed") return None + def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any]]: """Parse the LLM response and extract the table schema with validation""" try: @@ -445,31 +478,33 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any response = response.split("```json")[1].split("```")[0].strip() elif "```" in response: response = response.split("```")[1].strip() - + # Handle common formatting issues - response = response.replace('\n', ' ').replace('\r', ' ') - + response = response.replace("\n", " ").replace("\r", " ") + # Cleanup any trailing/leading text - start_idx = response.find('{') - end_idx = response.rfind('}') + 1 + start_idx = response.find("{") + end_idx = response.rfind("}") + 1 if start_idx >= 0 and end_idx > start_idx: response = response[start_idx:end_idx] - + parsed = json.loads(response) - + # Validation of required components if table_name in parsed: table_data = parsed[table_name] required_keys = ["description", "columns", "indexes", "foreign_keys"] - + # Check if all required sections exist if all(key in table_data for key in required_keys): # Verify columns have required attributes for col_name, col_data in table_data["columns"].items(): required_col_attrs = ["description", "type", "null"] if not all(attr in col_data for attr in required_col_attrs): - print(f"Warning: Column {col_name} is missing required attributes") - + print( + f"Warning: Column {col_name} is missing required attributes" + ) + return {table_name: table_data} else: missing = [key for key in required_keys if key not in table_data] @@ -478,77 +513,83 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any else: # Try to get the first key if table_name is not found first_key = next(iter(parsed)) - print(f"Warning: Table name mismatch. Expected {table_name}, got {first_key}") + print( + f"Warning: Table name mismatch. Expected {table_name}, got {first_key}" + ) return {table_name: parsed[first_key]} except Exception as e: print(f"Error parsing LLM response for {table_name}: {e}") print(f"Raw response: {response[:500]}...") # Show first 500 chars return None -def process_table(table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology) -> Dict[str, Any]: + +def process_table( + table_name: str, schema: Dict[str, Any], all_table_names: List[str], topology +) -> Dict[str, Any]: """Process a single table and update the schema""" print(f"Processing table: {table_name}") - + # Skip if table already has detailed schema - if (table_name in schema["tables"] and - isinstance(schema["tables"][table_name], dict) and - "columns" in schema["tables"][table_name] and - "indexes" in schema["tables"][table_name] and - "foreign_keys" in schema["tables"][table_name]): + if ( + table_name in schema["tables"] + and isinstance(schema["tables"][table_name], dict) + and "columns" in schema["tables"][table_name] + and "indexes" in schema["tables"][table_name] + and "foreign_keys" in schema["tables"][table_name] + ): print(f"Table {table_name} already processed. Skipping.") return schema - + # Generate prompt for this table prompt = get_table_prompt(table_name, schema["tables"], all_table_names, topology) - + # Call LLM API response = call_llm_api(prompt) if not response: print(f"Failed to get response for {table_name}. Skipping.") return schema - + # Parse response table_schema = parse_llm_response(response, table_name) if not table_schema: print(f"Failed to parse response for {table_name}. Skipping.") return schema - + # Update schema schema["tables"].update(table_schema) print(f"Successfully processed {table_name}") - + # Save intermediate results save_schema(schema, f"intermediate_{table_name.replace('/', '_')}.json") - + return schema + def main(): # Load the initial schema with table names - initial_schema_path = "examples/crm_tables.json" # Replace with your actual file path + initial_schema_path = ( + "examples/crm_tables.json" # Replace with your actual file path + ) initial_schema = load_initial_schema(initial_schema_path) - + # Get the list of tables to process tables = list(initial_schema.get("tables", {}).keys()) all_table_names = tables.copy() # Keep a full list for reference - - topology = generate_keys(tables) + topology = generate_keys(tables) # Initialize our working schema - schema = { - "database": initial_schema.get("database", "crm_system"), - "tables": {} - } - + schema = {"database": initial_schema.get("database", "crm_system"), "tables": {}} + # If we have existing work, load it if os.path.exists(OUTPUT_FILE): try: - with open(OUTPUT_FILE, 'r') as file: + with open(OUTPUT_FILE, "r") as file: schema = json.load(file) print(f"Loaded existing schema from {OUTPUT_FILE}") except Exception as e: print(f"Error loading existing schema: {e}") - + # Prioritize tables to process - process base tables first def table_priority(table_name): # Base tables should be processed first @@ -559,38 +600,41 @@ def table_priority(table_name): return 2 # Related tables in the middle return 1 - + # Sort tables by priority tables.sort(key=table_priority) - + # Process tables for i, table_name in enumerate(tables): - print(f"\nProcessing table {i+1}/{len(tables)}: {table_name} (Priority: {table_priority(table_name)})") + print( + f"\nProcessing table {i+1}/{len(tables)}: {table_name} (Priority: {table_priority(table_name)})" + ) schema = process_table(table_name, schema, all_table_names, topology) - + # Save progress after each table save_schema(schema) - + # Add delay to avoid rate limits if i < len(tables) - 1: delay = 2 + (0.5 * i % 5) # Varied delay to help avoid pattern detection print(f"Waiting {delay} seconds before next request...") time.sleep(delay) - + print(f"\nCompleted processing all {len(tables)} tables") print(f"Final schema saved to {OUTPUT_FILE}") - + # Validate the final schema validate_schema(schema) + def generate_keys(tables) -> Dict[str, Any]: path = "examples/crm_topology.json" # If we have existing work, load it if os.path.exists(path): try: - with open(path, 'r') as file: + with open(path, "r") as file: schema = json.load(file) - last_key = tables.index(list(schema['tables'].keys())[-1]) + last_key = tables.index(list(schema["tables"].keys())[-1]) print(f"Loaded existing schema from {path}") except Exception as e: print(f"Error loading existing schema: {e}") @@ -618,9 +662,9 @@ def generate_keys(tables) -> Dict[str, Any]: p = prompt.format(table_name=table, tables=tables) response = call_llm_api(p) new_table = json.loads(response) - schema['tables'].update(new_table) + schema["tables"].update(new_table) - with open(path, 'w') as file: + with open(path, "w") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {path}") print(f"Final schema saved to {path}") @@ -632,27 +676,32 @@ def validate_schema(schema: Dict[str, Any]) -> None: """Perform final validation on the complete schema""" print("\nValidating schema...") issues = [] - + table_count = len(schema["tables"]) - tables_with_columns = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "columns" in t) - tables_with_indexes = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "indexes" in t) - tables_with_foreign_keys = sum(1 for t in schema["tables"].values() - if isinstance(t, dict) and "foreign_keys" in t) - + tables_with_columns = sum( + 1 for t in schema["tables"].values() if isinstance(t, dict) and "columns" in t + ) + tables_with_indexes = sum( + 1 for t in schema["tables"].values() if isinstance(t, dict) and "indexes" in t + ) + tables_with_foreign_keys = sum( + 1 + for t in schema["tables"].values() + if isinstance(t, dict) and "foreign_keys" in t + ) + print(f"Total tables: {table_count}") print(f"Tables with columns: {tables_with_columns}") print(f"Tables with indexes: {tables_with_indexes}") print(f"Tables with foreign keys: {tables_with_foreign_keys}") - + # Check if all tables have required sections incomplete_tables = [] for table_name, table_data in schema["tables"].items(): if not isinstance(table_data, dict): incomplete_tables.append(f"{table_name} (empty)") continue - + missing = [] if "description" not in table_data or not table_data["description"]: missing.append("description") @@ -662,10 +711,10 @@ def validate_schema(schema: Dict[str, Any]) -> None: missing.append("indexes") if "foreign_keys" not in table_data: # Can be empty, just needs to exist missing.append("foreign_keys") - + if missing: incomplete_tables.append(f"{table_name} (missing: {', '.join(missing)})") - + if incomplete_tables: issues.append(f"Incomplete tables: {len(incomplete_tables)}") print("Incomplete tables:") @@ -673,26 +722,35 @@ def validate_schema(schema: Dict[str, Any]) -> None: print(f" - {table}") if len(incomplete_tables) > 10: print(f" ... and {len(incomplete_tables) - 10} more") - + # Check foreign key references invalid_fks = [] for table_name, table_data in schema["tables"].items(): if not isinstance(table_data, dict) or "foreign_keys" not in table_data: continue - + for fk_name, fk_data in table_data["foreign_keys"].items(): ref_table = fk_data.get("referenced_table") ref_column = fk_data.get("referenced_column") - + if ref_table and ref_table not in schema["tables"]: - invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (table not found)") + invalid_fks.append( + f"{table_name}.{fk_name} -> {ref_table} (table not found)" + ) elif ref_table and ref_column: ref_table_data = schema["tables"].get(ref_table, {}) - if not isinstance(ref_table_data, dict) or "columns" not in ref_table_data: - invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (no columns)") + if ( + not isinstance(ref_table_data, dict) + or "columns" not in ref_table_data + ): + invalid_fks.append( + f"{table_name}.{fk_name} -> {ref_table} (no columns)" + ) elif ref_column not in ref_table_data.get("columns", {}): - invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table}.{ref_column} (column not found)") - + invalid_fks.append( + f"{table_name}.{fk_name} -> {ref_table}.{ref_column} (column not found)" + ) + if invalid_fks: issues.append(f"Invalid foreign keys: {len(invalid_fks)}") print("Invalid foreign keys:") @@ -700,11 +758,12 @@ def validate_schema(schema: Dict[str, Any]) -> None: print(f" - {fk}") if len(invalid_fks) > 10: print(f" ... and {len(invalid_fks) - 10} more") - + if issues: print(f"\nValidation complete. Found {len(issues)} issue types.") else: print("\nValidation complete. No issues found!") + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/api/index.py b/api/index.py index bc7c3ac8..2fd62936 100644 --- a/api/index.py +++ b/api/index.py @@ -1,46 +1,59 @@ -""" This module contains the routes for the text2sql API. """ +"""This module contains the routes for the text2sql API.""" + import json -import os import logging +import os +import random +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import TimeoutError as FuturesTimeoutError from functools import wraps + from dotenv import load_dotenv -from flask import Blueprint, Response, jsonify, render_template, request, stream_with_context, Flask -from concurrent.futures import ThreadPoolExecutor, TimeoutError as FuturesTimeoutError -from api.graph import find, get_db_description +from flask import (Blueprint, Flask, Response, jsonify, render_template, + request, stream_with_context) + +from api.agents import AnalysisAgent, RelevancyAgent +from api.constants import BENCHMARK, EXAMPLES from api.extensions import db +from api.graph import find, get_db_description from api.loaders.csv_loader import CSVLoader from api.loaders.json_loader import JSONLoader from api.loaders.odata_loader import ODataLoader -from api.agents import RelevancyAgent, AnalysisAgent -from api.constants import BENCHMARK, EXAMPLES -import random # Load environment variables from .env file load_dotenv() -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) # Use the same delimiter as in the JavaScript -MESSAGE_DELIMITER = '|||FALKORDB_MESSAGE_BOUNDARY|||' +MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" main = Blueprint("main", __name__) -SECRET_TOKEN = os.getenv('SECRET_TOKEN') -SECRET_TOKEN_ERP = os.getenv('SECRET_TOKEN_ERP') +SECRET_TOKEN = os.getenv("SECRET_TOKEN") +SECRET_TOKEN_ERP = os.getenv("SECRET_TOKEN_ERP") + + def verify_token(token): - """ Verify the token provided in the request """ + """Verify the token provided in the request""" return token == SECRET_TOKEN or token == SECRET_TOKEN_ERP or token == "null" + def token_required(f): - """ Decorator to protect routes with token authentication """ + """Decorator to protect routes with token authentication""" + @wraps(f) def decorated_function(*args, **kwargs): - token = request.args.get('token', 'null') # Get token from header + token = request.args.get("token", "null") # Get token from header os.environ["USER_TOKEN"] = token if not verify_token(token): return jsonify(message="Unauthorized"), 401 return f(*args, **kwargs) + return decorated_function + app = Flask(__name__) # @app.before_request @@ -53,13 +66,15 @@ def decorated_function(*args, **kwargs): # # Optional: require it for protected routes # pass -@app.route('/') + +@app.route("/") @token_required # Apply token authentication decorator def home(): - """ Home route """ - return render_template('chat.html') + """Home route""" + return render_template("chat.html") -@app.route('/graphs') + +@app.route("/graphs") @token_required # Apply token authentication decorator def graphs(): """ @@ -67,25 +82,26 @@ def graphs(): """ graphs = db.list_graphs() if os.getenv("USER_TOKEN") == SECRET_TOKEN: - if 'hospital' in graphs: - return ['hospital'] + if "hospital" in graphs: + return ["hospital"] else: return [] - + if os.getenv("USER_TOKEN") == SECRET_TOKEN_ERP: - if 'ERP_system' in graphs: - return ['ERP_system'] + if "ERP_system" in graphs: + return ["ERP_system"] else: - return ['crm_usecase'] + return ["crm_usecase"] elif os.getenv("USER_TOKEN") == "null": - if 'crm_usecase' in graphs: - return ['crm_usecase'] + if "crm_usecase" in graphs: + return ["crm_usecase"] else: return [] else: - graphs.remove('hospital') + graphs.remove("hospital") return graphs + @app.route("/graphs", methods=["POST"]) @token_required # Apply token authentication decorator def load(): @@ -110,7 +126,9 @@ def load(): success, result = JSONLoader.load(graph_id, data) # ✅ Handle XML Payload - elif content_type.startswith("application/xml") or content_type.startswith("text/xml"): + elif content_type.startswith("application/xml") or content_type.startswith( + "text/xml" + ): xml_data = request.data graph_id = "" success, result = ODataLoader.load(graph_id, xml_data) @@ -144,13 +162,13 @@ def load(): xml_data = file.read().decode("utf-8") # Convert bytes to string graph_id = file.filename.replace(".xml", "") success, result = ODataLoader.load(graph_id, xml_data) - + # ✅ Check if file is csv elif file.filename.endswith(".csv"): csv_data = file.read().decode("utf-8") # Convert bytes to string graph_id = file.filename.replace(".csv", "") success, result = CSVLoader.load(graph_id, csv_data) - + else: return jsonify({"error": "Unsupported file type"}), 415 else: @@ -162,6 +180,7 @@ def load(): return jsonify({"error": result}), 400 + @app.route("/graphs/", methods=["POST"]) @token_required # Apply token authentication decorator def query(graph_id: str): @@ -174,7 +193,7 @@ def query(graph_id: str): instructions = request_data.get("instructions") if not queries_history: return jsonify({"error": "Invalid or missing JSON data"}), 400 - + logging.info(f"User Query: {queries_history[-1]}") # Create a generator function for streaming @@ -182,47 +201,68 @@ def generate(): agent_rel = RelevancyAgent(queries_history, result_history) agent_an = AnalysisAgent(queries_history, result_history) - step = {"type": "reasoning_step", "message": "Step 1: Analyzing the user query"} yield json.dumps(step) + MESSAGE_DELIMITER - db_description = get_db_description(graph_id) # Ensure the database description is loaded - + db_description = get_db_description( + graph_id + ) # Ensure the database description is loaded + logging.info(f"Calling to relvancy agent with query: {queries_history[-1]}") answer_rel = agent_rel.get_answer(queries_history[-1], db_description) if answer_rel["status"] != "On-topic": - step = {"type": "followup_questions", "message": "Off topic question: " + answer_rel["reason"]} + step = { + "type": "followup_questions", + "message": "Off topic question: " + answer_rel["reason"], + } logging.info(f"SQL Fail reason: {answer_rel["reason"]}") yield json.dumps(step) + MESSAGE_DELIMITER else: # Use a thread pool to enforce timeout with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(find, graph_id, queries_history, db_description) + future = executor.submit( + find, graph_id, queries_history, db_description + ) try: success, result, _ = future.result(timeout=120) except FuturesTimeoutError: - yield json.dumps({"type": "error", "message": "Timeout error while finding tables relevant to your request."}) + MESSAGE_DELIMITER + yield json.dumps( + { + "type": "error", + "message": "Timeout error while finding tables relevant to your request.", + } + ) + MESSAGE_DELIMITER return except Exception as e: logging.info(f"Error in find function: {e}") - yield json.dumps({"type": "error", "message": "Error in find function"}) + MESSAGE_DELIMITER + yield json.dumps( + {"type": "error", "message": "Error in find function"} + ) + MESSAGE_DELIMITER return - step = {"type": "reasoning_step", - "message": "Step 2: Generating SQL query"} + step = {"type": "reasoning_step", "message": "Step 2: Generating SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER logging.info(f"Calling to analysis agent with query: {queries_history[-1]}") - answer_an = agent_an.get_analysis(queries_history[-1], result, db_description, instructions) + answer_an = agent_an.get_analysis( + queries_history[-1], result, db_description, instructions + ) logging.info(f"SQL Result: {answer_an['sql_query']}") - yield json.dumps({"type": "final_result", "data": answer_an['sql_query'], "conf": answer_an['confidence'], - "miss": answer_an['missing_information'], - "amb": answer_an['ambiguities'], - "exp": answer_an['explanation'], - "is_valid": answer_an['is_sql_translatable']}) + MESSAGE_DELIMITER + yield json.dumps( + { + "type": "final_result", + "data": answer_an["sql_query"], + "conf": answer_an["confidence"], + "miss": answer_an["missing_information"], + "amb": answer_an["ambiguities"], + "exp": answer_an["explanation"], + "is_valid": answer_an["is_sql_translatable"], + } + ) + MESSAGE_DELIMITER + + return Response(stream_with_context(generate()), content_type="application/json") - return Response(stream_with_context(generate()), content_type='application/json') -@app.route('/suggestions') +@app.route("/suggestions") @token_required # Apply token authentication decorator def suggestions(): """ @@ -232,25 +272,28 @@ def suggestions(): """ try: # Get graph_id from query parameters - graph_id = request.args.get('graph_id', '') - + graph_id = request.args.get("graph_id", "") + if not graph_id: return jsonify([]), 400 - + # Check if graph has specific examples if graph_id in EXAMPLES: graph_examples = EXAMPLES[graph_id] # Return up to 3 examples, or all if less than 3 - suggestion_questions = random.sample(graph_examples, min(3, len(graph_examples))) + suggestion_questions = random.sample( + graph_examples, min(3, len(graph_examples)) + ) return jsonify(suggestion_questions) else: # If graph doesn't exist in EXAMPLES, return empty list return jsonify([]) - + except Exception as e: logging.error(f"Error fetching suggestions: {e}") return jsonify([]), 500 + if __name__ == "__main__": app.register_blueprint(main) app.run(debug=True) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 72d2b20b..7b46649b 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -1,6 +1,7 @@ from abc import ABC from typing import Tuple + class BaseLoader(ABC): @staticmethod diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index f0913e91..3a7569e1 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -1,11 +1,13 @@ -from typing import Tuple, Dict, List import io +from collections import defaultdict +from typing import Dict, List, Tuple + # import pandas as pd import tqdm -from collections import defaultdict from litellm import embedding -from api.loaders.base_loader import BaseLoader + from api.extensions import db +from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph @@ -14,25 +16,36 @@ class CSVLoader(BaseLoader): def load(graph_id: str, data) -> Tuple[bool, str]: """ Load the data dictionary CSV file into the graph database. - + Args: graph_id: The ID of the graph to load the data into data: CSV file - + Returns: Tuple of (success, message) """ raise NotImplementedError("CSVLoader is not implemented yet") try: # Parse CSV data using pandas for better handling of large files - df = pd.read_csv(io.StringIO(data), encoding='utf-8') + df = pd.read_csv(io.StringIO(data), encoding="utf-8") # Check if required columns exist - required_columns = ['Schema', 'Domain', 'Field', 'Type', 'Description', 'Related', 'Cardinality'] + required_columns = [ + "Schema", + "Domain", + "Field", + "Type", + "Description", + "Related", + "Cardinality", + ] missing_columns = [col for col in required_columns if col not in df.columns] - + if missing_columns: - return False, f"Missing required columns in CSV: {', '.join(missing_columns)}" + return ( + False, + f"Missing required columns in CSV: {', '.join(missing_columns)}", + ) db_name = """Abacus Domain Model 25.3.5 The Abacus Domain Model is a physical manifestation of the hierarchical object model that Abacus Insights uses to store data. (It is not a relational database.) It is a foundational aspect of @@ -41,62 +54,82 @@ def load(graph_id: str, data) -> Tuple[bool, str]: The Abacus Domain Model is organized into schemas, which group related domains. We implement each domain as a broad structure with minimal nesting. The model avoids inheritance and deep nesting to minimize complexity and optimize performance.""" - # Process data by grouping by Schema and Domain to identify tables # Group by Schema and Domain to get tables - tables = defaultdict(lambda: { - 'description': '', - 'columns': {}, - # 'relationships': [], - 'col_descriptions': [] - }) - - rel_table = defaultdict(lambda: { - 'primary_key_table': '', - 'fk_tables': [] - }) + tables = defaultdict( + lambda: { + "description": "", + "columns": {}, + # 'relationships': [], + "col_descriptions": [], + } + ) + + rel_table = defaultdict(lambda: {"primary_key_table": "", "fk_tables": []}) relationships = {} # First pass: Organize data into tables - for idx, row in tqdm.tqdm(df.iterrows(), total=len(df), desc="Organizing data"): - schema = row['Schema'] - domain = row['Domain'] + for idx, row in tqdm.tqdm( + df.iterrows(), total=len(df), desc="Organizing data" + ): + schema = row["Schema"] + domain = row["Domain"] table_name = f"{schema}.{domain}" - + # Set table description (use Domain Description if available) - if 'Domain Description' in row and not pd.isna(row['Domain Description']) and not tables[table_name]['description']: - tables[table_name]['description'] = row['Domain Description'] - + if ( + "Domain Description" in row + and not pd.isna(row["Domain Description"]) + and not tables[table_name]["description"] + ): + tables[table_name]["description"] = row["Domain Description"] + # Add column information - field = row['Field'] - field_type = row['Type'] if not pd.isna(row['Type']) else 'STRING' - field_desc = row['Description'] if not pd.isna(row['Description']) else field - - nullable = True # Default to nullable since we don't have explicit null info + field = row["Field"] + field_type = row["Type"] if not pd.isna(row["Type"]) else "STRING" + field_desc = ( + row["Description"] if not pd.isna(row["Description"]) else field + ) + + nullable = ( + True # Default to nullable since we don't have explicit null info + ) if not pd.isna(field): - tables[table_name]['col_descriptions'].append(field_desc) - tables[table_name]['columns'][field] = { - 'type': field_type, - 'description': field_desc, - 'null': nullable, - 'key': 'PRI' if field.lower().endswith('_id') else '', # Assumption: *_id fields are primary keys - 'default': '', - 'extra': '' + tables[table_name]["col_descriptions"].append(field_desc) + tables[table_name]["columns"][field] = { + "type": field_type, + "description": field_desc, + "null": nullable, + "key": ( + "PRI" if field.lower().endswith("_id") else "" + ), # Assumption: *_id fields are primary keys + "default": "", + "extra": "", } - + # Add relationship information if available - if not pd.isna(row['Related']) and not pd.isna(row['Cardinality']): + if not pd.isna(row["Related"]) and not pd.isna(row["Cardinality"]): source_field = field - target_table = row['Related'] + target_table = row["Related"] # cardinality = row['Cardinality'] if table_name not in relationships: relationships[table_name] = [] - relationships[table_name].append({"from": table_name, - "to": target_table, - "source_column": source_field, - "target_column": df.to_dict("records")[idx+1]['Array Field'] if not pd.isna(df.to_dict("records")[idx+1]['Array Field']) else '', - "note": ""}) - + relationships[table_name].append( + { + "from": table_name, + "to": target_table, + "source_column": source_field, + "target_column": ( + df.to_dict("records")[idx + 1]["Array Field"] + if not pd.isna( + df.to_dict("records")[idx + 1]["Array Field"] + ) + else "" + ), + "note": "", + } + ) + # tables[table_name]['relationships'].append({ # 'source_field': source_field, # 'target_table': target_table, @@ -106,95 +139,110 @@ def load(graph_id: str, data) -> Tuple[bool, str]: tables[target_table]["description"] = field_desc else: - field = row['Array Field'] + field = row["Array Field"] field_desc = field_desc if not pd.isna(field_desc) else field # if len(tables[target_table]['col_descriptions']) == 0: # tables[table_name]['relationships'][-1]['target_field'] = field - tables[target_table]['col_descriptions'].append(field_desc) - tables[target_table]['columns'][field] = { - 'type': field_type, - 'description': field_desc, - 'null': nullable, - 'key': 'PRI' if field.lower().endswith('_id') else '', # Assumption: *_id fields are primary keys - 'default': '', - 'extra': '' + tables[target_table]["col_descriptions"].append(field_desc) + tables[target_table]["columns"][field] = { + "type": field_type, + "description": field_desc, + "null": nullable, + "key": ( + "PRI" if field.lower().endswith("_id") else "" + ), # Assumption: *_id fields are primary keys + "default": "", + "extra": "", } - if field.endswith('_id'): - if len(tables[table_name]['columns']) == 1 and field.endswith('_id'): + if field.endswith("_id"): + if len(tables[table_name]["columns"]) == 1 and field.endswith( + "_id" + ): suspected_primary_key = field[:-3] if suspected_primary_key in domain: - rel_table[field]['primary_key_table'] = table_name + rel_table[field]["primary_key_table"] = table_name else: - rel_table[field]['fk_tables'].append(table_name) + rel_table[field]["fk_tables"].append(table_name) else: - rel_table[field]['fk_tables'].append(table_name) + rel_table[field]["fk_tables"].append(table_name) - for key, tables_info in tqdm.tqdm(rel_table.items(), desc="Creating relationships from names"): - if len(tables_info['fk_tables']) > 0: - fk_tables = list(set(tables_info['fk_tables'])) - if len(tables_info['primary_key_table']) > 0: + for key, tables_info in tqdm.tqdm( + rel_table.items(), desc="Creating relationships from names" + ): + if len(tables_info["fk_tables"]) > 0: + fk_tables = list(set(tables_info["fk_tables"])) + if len(tables_info["primary_key_table"]) > 0: for table in fk_tables: if table not in relationships: relationships[table_name] = [] - relationships[table].append({"from": table, - "to": tables_info['primary_key_table'], - "source_column": key, - "target_column": key, - "note": 'many-one'}) + relationships[table].append( + { + "from": table, + "to": tables_info["primary_key_table"], + "source_column": key, + "target_column": key, + "note": "many-one", + } + ) else: for table_1 in fk_tables: for table_2 in fk_tables: if table_1 != table_2: if table_1 not in relationships: relationships[table_1] = [] - relationships[table_1].append({"from": table_1, - "to": table_2, - "source_column": key, - "target_column": key, - "note": 'many-many'}) - + relationships[table_1].append( + { + "from": table_1, + "to": table_2, + "source_column": key, + "target_column": key, + "note": "many-many", + } + ) + load_to_graph(graph_id, tables, relationships, db_name=db_name) return True, "Data dictionary loaded successfully into graph" - + except Exception as e: return False, f"Error loading CSV: {str(e)}" - # else: - # # For case 2: when no primary key table exists, connect all FK tables to each other - # graph.query( - # """ - # CREATE (src: Column {name: $col, cardinality: $cardinality}) - # """, - # { - # 'col': key, - # 'cardinality': 'many-many' - # } - # ) - # for i in range(len(fk_tables)): - # graph.query( - # """ - # MATCH (src:Column {name: $source_col}) - # -[:BELONGS_TO]->(source:Table {name: $source_table}) - # MATCH (tgt:Column {name: $target_col, cardinality: $cardinality}) - # CREATE (src)-[:REFERENCES { - # constraint_name: $fk_name, - # cardinality: $cardinality - # }]->(tgt) - # """, - # { - # 'source_col': key, - # 'target_col': key, - # 'source_table': fk_tables[i], - # 'fk_name': key, - # 'cardinality': 'many-many' - # } - # ) + # else: + # # For case 2: when no primary key table exists, connect all FK tables to each other + # graph.query( + # """ + # CREATE (src: Column {name: $col, cardinality: $cardinality}) + # """, + # { + # 'col': key, + # 'cardinality': 'many-many' + # } + # ) + # for i in range(len(fk_tables)): + # graph.query( + # """ + # MATCH (src:Column {name: $source_col}) + # -[:BELONGS_TO]->(source:Table {name: $source_table}) + # MATCH (tgt:Column {name: $target_col, cardinality: $cardinality}) + # CREATE (src)-[:REFERENCES { + # constraint_name: $fk_name, + # cardinality: $cardinality + # }]->(tgt) + # """, + # { + # 'source_col': key, + # 'target_col': key, + # 'source_table': fk_tables[i], + # 'fk_name': key, + # 'cardinality': 'many-many' + # } + # ) + # # Second pass: Create table nodes # for table_name, table_info in tqdm.tqdm(tables.items(), desc="Creating Table nodes"): # # Skip if no columns (probably just a reference) # if not table_info['columns']: # continue - + # # Generate embedding for table description # table_desc = table_info['description'] # embedding_result = client.models.embed_content( @@ -206,7 +254,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # graph.query( # """ # CREATE (t:Table { -# name: $table_name, +# name: $table_name, # description: $description, # embedding: vecf32($embedding) # }) @@ -224,23 +272,23 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # for batch in tqdm.tqdm( # [col_descriptions[i:i + batch_size] for i in range(0, len(col_descriptions), batch_size)], # desc=f"Creating embeddings for {table_name}"): - + # embedding_result = embedding(model='bedrock/cohere.embed-english-v3', input=batch[:95], aws_profile_name=Config.AWS_PROFILE, aws_region_name=Config.AWS_REGION) # embed_columns.extend([emb.values for emb in embedding_result.embeddings]) # except Exception as e: # print(f"Error creating embeddings: {str(e)}") - + # # Create column nodes # for idx, (col_name, col_info) in tqdm.tqdm(enumerate(table_info['columns'].items()), desc=f"Creating columns for {table_name}", total=len(table_info['columns'])): # # embedding_result = embedding( # # model=Config.EMBEDDING_MODEL, # # input=[col_info['description'] if col_info['description'] else col_name] # # ) - + # ## Temp # # agent_tax = TaxonomyAgent() # # tax = agent_tax.get_answer(col_name, col_info) -# # # +# # # # graph.query( # """ # MATCH (t:Table {name: $table_name}) @@ -267,7 +315,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # 'embedding': embed_columns[idx] # } # ) - + # # Third pass: Create relationships # for table_name, table_info in tqdm.tqdm(tables.items(), desc="Creating relationships"): # for rel in table_info['relationships']: @@ -277,7 +325,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # target_field = rel['target_field']#list(tables[tables[table_name]['relationships'][-1]['target_table']]['columns'].keys())[0] # # Create constraint name # constraint_name = f"fk_{table_name.replace('.', '_')}_{source_field}_to_{target_table.replace('.', '_')}" - + # # Create relationship if both tables and columns exist # try: # graph.query( @@ -359,15 +407,15 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # } # ) - # load_to_graph(graph_id, entities, relationships, db_name="ERP system") - # return True, "Data dictionary loaded successfully into graph" - - # except Exception as e: - # return False, f"Error loading CSV: {str(e)}" +# load_to_graph(graph_id, entities, relationships, db_name="ERP system") +# return True, "Data dictionary loaded successfully into graph" + +# except Exception as e: +# return False, f"Error loading CSV: {str(e)}" # if __name__ == "__main__": # # Example usage # loader = CSVLoader() # success, message = loader.load("my_graph", "Data Dictionary.csv") -# print(message) \ No newline at end of file +# print(message) diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index eb551d7b..b02918e9 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -1,14 +1,23 @@ +import json + import tqdm + from api.config import Config from api.extensions import db from api.utils import generate_db_description -import json -def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size: int=100, db_name: str="TBD") -> None: + +def load_to_graph( + graph_id: str, + entities: dict, + relationships: dict, + batch_size: int = 100, + db_name: str = "TBD", +) -> None: """ Load the graph data into the database. It gets the Graph name as an argument and expects - + Input: - entities: A dictionary containing the entities and their attributes. - relationships: A dictionary containing the relationships between entities. @@ -19,33 +28,28 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size embedding_model = Config.EMBEDDING_MODEL vec_len = embedding_model.get_vector_size() - try: + try: # Create vector indices - graph.query(""" + graph.query( + """ CREATE VECTOR INDEX FOR (t:Table) ON (t.embedding) OPTIONS {dimension:$size, similarityFunction:'euclidean'} """, - { - 'size': vec_len - }) - - graph.query(""" + {"size": vec_len}, + ) + + graph.query( + """ CREATE VECTOR INDEX FOR (c:Column) ON (c.embedding) OPTIONS {dimension:$size, similarityFunction:'euclidean'} """, - { - 'size': vec_len - }) + {"size": vec_len}, + ) graph.query("CREATE INDEX FOR (p:Table) ON (p.name)") except Exception as e: print(f"Error creating vector indices: {str(e)}") - - - db_des = generate_db_description( - db_name=db_name, - table_names=list(entities.keys()) - ) + db_des = generate_db_description(db_name=db_name, table_names=list(entities.keys())) graph.query( """ CREATE (d:Database { @@ -53,17 +57,15 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size description: $description }) """, - { - 'db_name': db_name, - 'description': db_des - } + {"db_name": db_name, "description": db_des}, ) - - for table_name, table_info in tqdm.tqdm(entities.items(), desc="Creating Graph Table Nodes"): - table_desc = table_info['description'] + for table_name, table_info in tqdm.tqdm( + entities.items(), desc="Creating Graph Table Nodes" + ): + table_desc = table_info["description"] embedding_result = embedding_model.embed(table_desc) - fk = json.dumps(table_info.get('foreign_keys', [])) + fk = json.dumps(table_info.get("foreign_keys", [])) # Create table node graph.query( @@ -76,38 +78,45 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size }) """, { - 'table_name': table_name, - 'description': table_desc, - 'embedding': embedding_result[0], - 'foreign_keys': fk - } + "table_name": table_name, + "description": table_desc, + "embedding": embedding_result[0], + "foreign_keys": fk, + }, ) - # Batch embeddings for table columns # TODO: Check if the embedding model and description are correct (without 2 sources of truth) batch_flag = True - col_descriptions = table_info.get('col_descriptions') + col_descriptions = table_info.get("col_descriptions") if col_descriptions is None: batch_flag = False else: try: embed_columns = [] for batch in tqdm.tqdm( - [col_descriptions[i:i + batch_size] for i in range(0, len(col_descriptions), batch_size)], - desc=f"Creating embeddings for {table_name} columns",): - + [ + col_descriptions[i : i + batch_size] + for i in range(0, len(col_descriptions), batch_size) + ], + desc=f"Creating embeddings for {table_name} columns", + ): + embedding_result = embedding_model.embed(batch) embed_columns.extend(embedding_result) except Exception as e: print(f"Error creating embeddings: {str(e)}") batch_flag = False - + # Create column nodes - for idx, (col_name, col_info) in tqdm.tqdm(enumerate(table_info['columns'].items()), desc=f"Creating Graph Columns for {table_name}", total=len(table_info['columns'])): + for idx, (col_name, col_info) in tqdm.tqdm( + enumerate(table_info["columns"].items()), + desc=f"Creating Graph Columns for {table_name}", + total=len(table_info["columns"]), + ): if not batch_flag: embed_columns = [] - embedding_result = embedding_model.embed(col_info['description']) + embedding_result = embedding_model.embed(col_info["description"]) embed_columns.extend(embedding_result) idx = 0 @@ -124,25 +133,27 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size })-[:BELONGS_TO]->(t) """, { - 'table_name': table_name, - 'col_name': col_name, - 'type': col_info.get('type', 'unknown'), - 'nullable': col_info.get('null', 'unknown'), - 'key': col_info.get('key', 'unknown'), - 'description': col_info['description'], - 'embedding': embed_columns[idx] - } + "table_name": table_name, + "col_name": col_name, + "type": col_info.get("type", "unknown"), + "nullable": col_info.get("null", "unknown"), + "key": col_info.get("key", "unknown"), + "description": col_info["description"], + "embedding": embed_columns[idx], + }, ) - + # Create relationships - for rel_name, table_info in tqdm.tqdm(relationships.items(), desc="Creating Graph Table Relationships"): + for rel_name, table_info in tqdm.tqdm( + relationships.items(), desc="Creating Graph Table Relationships" + ): for rel in table_info: - source_table = rel['from'] - source_field = rel['source_column'] - target_table = rel['to'] - target_field = rel['target_column'] - note = rel.get('note', '') - + source_table = rel["from"] + source_field = rel["source_column"] + target_table = rel["to"] + target_field = rel["target_column"] + note = rel.get("note", "") + # Create relationship if both tables and columns exist try: graph.query( @@ -157,14 +168,14 @@ def load_to_graph(graph_id: str, entities: dict, relationships: dict, batch_size }]->(tgt) """, { - 'source_col': source_field, - 'target_col': target_field, - 'source_table': source_table, - 'target_table': target_table, - 'rel_name': rel_name, - 'note': note - } + "source_col": source_field, + "target_col": target_field, + "source_table": source_table, + "target_table": target_table, + "rel_name": rel_name, + "note": note, + }, ) except Exception as e: print(f"Warning: Could not create relationship: {str(e)}") - continue \ No newline at end of file + continue diff --git a/api/loaders/json_loader.py b/api/loaders/json_loader.py index 853e0809..12427828 100644 --- a/api/loaders/json_loader.py +++ b/api/loaders/json_loader.py @@ -1,23 +1,26 @@ -from typing import Tuple import json +from typing import Tuple + import tqdm from jsonschema import ValidationError, validate from litellm import embedding + from api.config import Config -from api.loaders.base_loader import BaseLoader from api.extensions import db -from api.utils import generate_db_description +from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph from api.loaders.schema_validator import validate_table_schema +from api.utils import generate_db_description try: - with open(Config.SCHEMA_PATH, 'r', encoding='utf-8') as f: + with open(Config.SCHEMA_PATH, "r", encoding="utf-8") as f: schema = json.load(f) except FileNotFoundError as exc: raise FileNotFoundError(f"Schema file not found: {Config.SCHEMA_PATH}") from exc except json.JSONDecodeError as exc: raise ValueError(f"Invalid schema JSON: {str(exc)}") from exc + class JSONLoader(BaseLoader): @staticmethod @@ -37,22 +40,32 @@ def load(graph_id: str, data) -> Tuple[bool, str]: print("❌ Schema validation failed with the following issues:") for error in validation_errors: print(f" - {error}") - raise ValidationError("Schema validation failed. Please check the schema and try again.") + raise ValidationError( + "Schema validation failed. Please check the schema and try again." + ) except ValidationError as exc: return False, str(exc) - + relationships = {} - for table_name, table_info in tqdm.tqdm(data['tables'].items(), "Create Table relationships"): + for table_name, table_info in tqdm.tqdm( + data["tables"].items(), "Create Table relationships" + ): # Create Foreign Key relationships - for fk_name, fk_info in tqdm.tqdm(table_info['foreign_keys'].items(), "Create Foreign Key relationships"): + for fk_name, fk_info in tqdm.tqdm( + table_info["foreign_keys"].items(), "Create Foreign Key relationships" + ): if table_name not in relationships: relationships[table_name] = [] - relationships[table_name].append({"from": table_name, - "to": fk_info['referenced_table'], - "source_column": fk_info['column'], - "target_column": fk_info['referenced_column'], - "note": fk_name}) - load_to_graph(graph_id, data['tables'], relationships, db_name=data['database']) - - return True, "Graph loaded successfully" \ No newline at end of file + relationships[table_name].append( + { + "from": table_name, + "to": fk_info["referenced_table"], + "source_column": fk_info["column"], + "target_column": fk_info["referenced_column"], + "note": fk_name, + } + ) + load_to_graph(graph_id, data["tables"], relationships, db_name=data["database"]) + + return True, "Graph loaded successfully" diff --git a/api/loaders/odata_loader.py b/api/loaders/odata_loader.py index 762fe525..31bfb787 100644 --- a/api/loaders/odata_loader.py +++ b/api/loaders/odata_loader.py @@ -1,9 +1,11 @@ import re -from typing import Tuple import xml.etree.ElementTree as ET +from typing import Tuple + import tqdm -from api.loaders.base_loader import BaseLoader + from api.extensions import db +from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph @@ -14,7 +16,7 @@ class ODataLoader(BaseLoader): @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: - """ Load XML ODATA schema into a Graph. """ + """Load XML ODATA schema into a Graph.""" try: # Parse the OData schema @@ -38,90 +40,121 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: # Define namespaces namespaces = { - 'edmx': "http://docs.oasis-open.org/odata/ns/edmx", - 'edm': "http://docs.oasis-open.org/odata/ns/edm" + "edmx": "http://docs.oasis-open.org/odata/ns/edmx", + "edm": "http://docs.oasis-open.org/odata/ns/edm", } schema_element = root.find(".//edmx:DataServices/edm:Schema", namespaces) if schema_element is None: raise ET.ParseError("Schema element not found") - + entity_types = schema_element.findall("edm:EntityType", namespaces) for entity_type in tqdm.tqdm(entity_types, "Parsing OData schema"): entity_name = entity_type.get("Name") - entities[entity_name] = {'col_descriptions': []} + entities[entity_name] = {"col_descriptions": []} entities[entity_name]["columns"] = {} for prop in entity_type.findall("edm:Property", namespaces): prop_name = prop.get("Name") try: if prop_name is not None: entities[entity_name]["columns"][prop_name] = {} - entities[entity_name]["columns"][prop_name]["type"] = prop.get("Type").split(".")[-1] + entities[entity_name]["columns"][prop_name]["type"] = prop.get( + "Type" + ).split(".")[-1] col_des = entity_name if len(prop.findall("edm:Annotation", namespaces)) > 0: - if len(prop.findall("edm:Annotation", namespaces)[0].get("String")) > 0: - col_des = prop.findall("edm:Annotation", namespaces)[0].get("String") + if ( + len( + prop.findall("edm:Annotation", namespaces)[0].get( + "String" + ) + ) + > 0 + ): + col_des = prop.findall("edm:Annotation", namespaces)[ + 0 + ].get("String") entities[entity_name]["col_descriptions"].append(col_des) - entities[entity_name]["columns"][prop_name]["description"] = col_des + entities[entity_name]["columns"][prop_name][ + "description" + ] = col_des except Exception as e: - print(f"Error parsing property {prop_name} for entity {entity_name}") + print( + f"Error parsing property {prop_name} for entity {entity_name}" + ) continue # = {prop.get("Name"): prop.get("Type") for prop in entity_type.findall("edm:Property", namespaces)} description = entity_type.findall("edm:Annotation", namespaces) if len(description) > 0: - entities[entity_name]["description"] = description[0].get("String").replace("'", "\\'") + entities[entity_name]["description"] = ( + description[0].get("String").replace("'", "\\'") + ) else: try: - entities[entity_name]["description"] = entity_name + " with Primery key: " + entity_type.find("edm:Key/edm:PropertyRef", namespaces).attrib['Name'] + entities[entity_name]["description"] = ( + entity_name + + " with Primery key: " + + entity_type.find( + "edm:Key/edm:PropertyRef", namespaces + ).attrib["Name"] + ) except: print(f"Error parsing description for entity {entity_name}") entities[entity_name]["description"] = entity_name + for entity_type in tqdm.tqdm( + entity_types, "Parsing OData schema - relationships" + ): - for entity_type in tqdm.tqdm(entity_types, "Parsing OData schema - relationships"): - - entity_name = entity_type.attrib['Name'] + entity_name = entity_type.attrib["Name"] for rel in entity_type.findall("edm:NavigationProperty", namespaces): rel_name = rel.get("Name") - raw_type = rel.get("Type") # e.g., 'Collection(Priority.OData.ABILITYVALUES)' + raw_type = rel.get( + "Type" + ) # e.g., 'Collection(Priority.OData.ABILITYVALUES)' # Clean 'Collection(...)' wrapper if exists if raw_type.startswith("Collection(") and raw_type.endswith(")"): - raw_type = raw_type[len("Collection("):-1] + raw_type = raw_type[len("Collection(") : -1] # Extract the target entity name - match = re.search(r'(\w+)$', raw_type) + match = re.search(r"(\w+)$", raw_type) target_entity = match.group(1) if match else "UNKNOWN" - - source_entity = entity_name target_entity = target_entity source_fields = entities.get(entity_name, {})["columns"] target_fields = entities.get(target_entity, {})["columns"] - #TODO This usage is for demonstration purposes only, it should be replaced with a more robust method - source_col, target_col = guess_relationship_columns(source_fields, target_fields) + # TODO This usage is for demonstration purposes only, it should be replaced with a more robust method + source_col, target_col = guess_relationship_columns( + source_fields, target_fields + ) if source_col and target_col: # Store the relationship if rel_name not in relationships: relationships[rel_name] = [] - # src_col, tgt_col = guess_relationship_columns(source_entity, target_entity, entities[source_entity], entities[target_entity]) - relationships[rel_name].append({ - "from": source_entity, - "to": target_entity, - "source_column": source_col, - "target_column": target_col, - "note": "inferred" if source_col and target_col else "implicit/subform" - }) - + # src_col, tgt_col = guess_relationship_columns(source_entity, target_entity, entities[source_entity], entities[target_entity]) + relationships[rel_name].append( + { + "from": source_entity, + "to": target_entity, + "source_column": source_col, + "target_column": target_col, + "note": ( + "inferred" + if source_col and target_col + else "implicit/subform" + ), + } + ) return entities, relationships -#TODO: this funtion is for demonstration purposes only, it should be replaced with a more robust method +# TODO: this funtion is for demonstration purposes only, it should be replaced with a more robust method def guess_relationship_columns(source_fields, target_fields): for src_key, src_meta in source_fields.items(): if src_key == "description": diff --git a/api/loaders/schema_validator.py b/api/loaders/schema_validator.py index 5ad2206f..5b6ef0b5 100644 --- a/api/loaders/schema_validator.py +++ b/api/loaders/schema_validator.py @@ -1,7 +1,7 @@ - REQUIRED_COLUMN_KEYS = {"description", "type", "null", "key", "default"} VALID_NULL_VALUES = {"YES", "NO"} + def validate_table_schema(schema): errors = [] @@ -26,25 +26,35 @@ def validate_table_schema(schema): # Check for missing required keys missing_keys = REQUIRED_COLUMN_KEYS - column_data.keys() if missing_keys: - errors.append(f"Column '{column_name}' in table '{table_name}' is missing keys: {missing_keys}") + errors.append( + f"Column '{column_name}' in table '{table_name}' is missing keys: {missing_keys}" + ) continue # Validate non-empty description if not column_data.get("description"): - errors.append(f"Column '{column_name}' in table '{table_name}' has an empty description") + errors.append( + f"Column '{column_name}' in table '{table_name}' has an empty description" + ) # Validate 'null' field if column_data["null"] not in VALID_NULL_VALUES: - errors.append(f"Column '{column_name}' in table '{table_name}' has invalid 'null' value: {column_data['null']}") + errors.append( + f"Column '{column_name}' in table '{table_name}' has invalid 'null' value: {column_data['null']}" + ) # Optional: validate foreign keys if "foreign_keys" in table_data: if not isinstance(table_data["foreign_keys"], dict): - errors.append(f"Foreign keys for table '{table_name}' must be a dictionary") + errors.append( + f"Foreign keys for table '{table_name}' must be a dictionary" + ) else: for fk_name, fk_data in table_data["foreign_keys"].items(): for key in ("column", "referenced_table", "referenced_column"): if key not in fk_data or not fk_data[key]: - errors.append(f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'") + errors.append( + f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'" + ) return errors diff --git a/api/utils.py b/api/utils.py index 60d8c7b1..26124d60 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,11 +1,18 @@ import json from typing import List, Tuple + from litellm import completion + from api.config import Config from api.constants import BENCHMARK -def generate_db_description(db_name: str, table_names: List[str], temperature: float = 0.5, - max_tokens: int = 150) -> str: + +def generate_db_description( + db_name: str, + table_names: List[str], + temperature: float = 0.5, + max_tokens: int = 150, +) -> str: """ Generates a short and concise description of a database. @@ -20,14 +27,14 @@ def generate_db_description(db_name: str, table_names: List[str], temperature: f """ if not isinstance(db_name, str): raise TypeError("database_name must be a string.") - + if not isinstance(table_names, list): raise TypeError("table_names must be a list of strings.") - + # Ensure all table names are strings if not all(isinstance(table, str) for table in table_names): raise ValueError("All items in table_names must be strings.") - + if not table_names: return f"{db_name} is a database with no tables." @@ -44,21 +51,25 @@ def generate_db_description(db_name: str, table_names: List[str], temperature: f f"which contains the following tables: {tables_formatted}.\n\n" f"Description:" ) - - response = completion(model=Config.COMPLETION_MODEL, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": prompt} - ], - temperature=temperature, - max_tokens=max_tokens, - n=1, - stop=None, - ) - description = response.choices[0].message['content'] + + response = completion( + model=Config.COMPLETION_MODEL, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ], + temperature=temperature, + max_tokens=max_tokens, + n=1, + stop=None, + ) + description = response.choices[0].message["content"] return description -def llm_answer_validator(question: str, answer: str, expected_answer: str = None) -> str: + +def llm_answer_validator( + question: str, answer: str, expected_answer: str = None +) -> str: prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the Generated Answer (generated sql) addresses the Question based on the Expected Answer. @@ -76,18 +87,28 @@ def llm_answer_validator(question: str, answer: str, expected_answer: str = None Output Json format: {{"relevance_score": float, "explanation": "Your assessment here."}} """ - response = completion(model=Config.VALIDTOR_MODEL, - messages=[ - {"role": "system", "content": "You are a Validator assistant."}, - {"role": "user", "content": prompt.format(question=question, expected_answer=expected_answer, generated_answer=answer)} - ], - response_format={"type": "json_object"}, - - ) - validation_set = response.choices[0].message['content'].strip() + response = completion( + model=Config.VALIDTOR_MODEL, + messages=[ + {"role": "system", "content": "You are a Validator assistant."}, + { + "role": "user", + "content": prompt.format( + question=question, + expected_answer=expected_answer, + generated_answer=answer, + ), + }, + ], + response_format={"type": "json_object"}, + ) + validation_set = response.choices[0].message["content"].strip() return validation_set -def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[float, str]: + +def llm_table_validator( + question: str, answer: str, tables: List[str] +) -> Tuple[float, str]: prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the retrived Tables relevant to the question and supports the Generated Answer (generated sql). - The tables are with the following structure: @@ -106,18 +127,24 @@ def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[ Output Json format: {{"relevance_score": float, "explanation": "Your assessment here."}} """ - response = completion(model=Config.VALIDTOR_MODEL, - messages=[ - {"role": "system", "content": "You are a Validator assistant."}, - {"role": "user", "content": prompt.format(question=question, tables=tables, generated_answer=answer)} - ], - response_format={"type": "json_object"}, - ) - validation_set = response.choices[0].message['content'].strip() + response = completion( + model=Config.VALIDTOR_MODEL, + messages=[ + {"role": "system", "content": "You are a Validator assistant."}, + { + "role": "user", + "content": prompt.format( + question=question, tables=tables, generated_answer=answer + ), + }, + ], + response_format={"type": "json_object"}, + ) + validation_set = response.choices[0].message["content"].strip() try: val_res = json.loads(validation_set) - score = val_res['relevance_score'] - explanation = val_res['explanation'] + score = val_res["relevance_score"] + explanation = val_res["explanation"] except Exception as e: print(f"Error: {e}") score = 0.0 @@ -138,8 +165,7 @@ def run_benchmark(): for data in benchmark_data: success, result = generate_db_description( - db_name=data['database'], - table_names=list(data['tables'].keys()) + db_name=data["database"], table_names=list(data["tables"].keys()) ) if success: @@ -147,4 +173,4 @@ def run_benchmark(): else: results.append(f"Error: {result}") - return results \ No newline at end of file + return results diff --git a/onthology.py b/onthology.py index f5f0007c..ecf98e5c 100644 --- a/onthology.py +++ b/onthology.py @@ -3,7 +3,7 @@ from graphrag_sdk.models.litellm import LiteModel model = LiteModel(model_name="gemini/gemini-2.0-flash") -db = FalkorDB(host='localhost', port=6379) +db = FalkorDB(host="localhost", port=6379) kg_name = "crm_system" ontology = Ontology.from_kg_graph(db.select_graph(kg_name), 1000000000) ontology.save_to_graph(db.select_graph(f"{{{kg_name}}}_schema")) From f6d2ccd0c871b9ff4224374eb61a79ba058734ad Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 10:16:39 +0300 Subject: [PATCH 03/58] fix long lines --- api/agents.py | 16 +++------- api/graph.py | 16 +++------- api/helpers/crm_data_generator.py | 53 ++++++++++--------------------- api/index.py | 23 ++++---------- api/loaders/csv_loader.py | 20 +++--------- api/loaders/graph_loader.py | 4 +-- api/loaders/odata_loader.py | 41 ++++++------------------ api/loaders/schema_validator.py | 4 +-- api/utils.py | 12 ++----- 9 files changed, 51 insertions(+), 138 deletions(-) diff --git a/api/agents.py b/api/agents.py index 1f5228dd..711843a6 100644 --- a/api/agents.py +++ b/api/agents.py @@ -24,9 +24,7 @@ def get_analysis( instructions: str = None, ) -> dict: formatted_schema = self._format_schema(combined_tables) - prompt = self._build_prompt( - user_query, formatted_schema, db_description, instructions - ) + prompt = self._build_prompt(user_query, formatted_schema, db_description, instructions) self.messages.append({"role": "user", "content": prompt}) completion_result = completion( model=Config.COMPLETION_MODEL, @@ -38,17 +36,13 @@ def get_analysis( response = completion_result.choices[0].message.content analysis = _parse_response(response) if isinstance(analysis["ambiguities"], list): - analysis["ambiguities"] = [ - item.replace("-", " ") for item in analysis["ambiguities"] - ] + analysis["ambiguities"] = [item.replace("-", " ") for item in analysis["ambiguities"]] analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) if isinstance(analysis["missing_information"], list): analysis["missing_information"] = [ item.replace("-", " ") for item in analysis["missing_information"] ] - analysis["missing_information"] = "- " + "- ".join( - analysis["missing_information"] - ) + analysis["missing_information"] = "- " + "- ".join(analysis["missing_information"]) self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) return analysis @@ -96,9 +90,7 @@ def _format_schema(self, schema_data: List) -> str: column = fk_info.get("column", "") ref_table = fk_info.get("referenced_table", "") ref_column = fk_info.get("referenced_column", "") - table_str += ( - f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" - ) + table_str += f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" formatted_schema.append(table_str) diff --git a/api/graph.py b/api/graph.py index c8938a1d..d7986787 100644 --- a/api/graph.py +++ b/api/graph.py @@ -11,9 +11,7 @@ from api.config import Config from api.extensions import db -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") class TableDescription(BaseModel): @@ -73,9 +71,7 @@ def find( response_format=Descriptions, messages=[ { - "content": Config.FIND_SYSTEM_PROMPT.format( - db_description=db_description - ), + "content": Config.FIND_SYSTEM_PROMPT.format(db_description=db_description), "role": "system", }, { @@ -99,9 +95,7 @@ def find( logging.info(f"Find tables based on: {descriptions.tables_descriptions}") tables_des = _find_tables(graph, descriptions.tables_descriptions) logging.info(f"Find tables based on columns: {descriptions.columns_descriptions}") - tables_by_columns_des = _find_tables_by_columns( - graph, descriptions.columns_descriptions - ) + tables_by_columns_des = _find_tables_by_columns(graph, descriptions.columns_descriptions) # table names for sphere and route extraction base_tables_names = [table[0] for table in tables_des] @@ -238,9 +232,7 @@ def _get_unique_tables(tables_list): return list(unique_tables.values()) -def find_connecting_tables( - graph, table_names: List[str] -) -> Tuple[List[dict], List[str]]: +def find_connecting_tables(graph, table_names: List[str]) -> Tuple[List[dict], List[str]]: """ Find all tables that form connections between any pair of tables in the input list. Handles both Table nodes and Column nodes with primary keys. diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index ce4818b2..575b72e6 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -45,9 +45,7 @@ def save_schema(schema: Dict[str, Any], output_file: str = OUTPUT_FILE) -> None: schema["metadata"]["key_registry"] = { "primary_keys": key_registry["primary_keys"], "foreign_keys": key_registry["foreign_keys"], - "table_relationships": { - k: list(v) for k, v in key_registry["table_relationships"].items() - }, + "table_relationships": {k: list(v) for k, v in key_registry["table_relationships"].items()}, } with open(output_file, "w") as file: @@ -135,9 +133,7 @@ def get_table_prompt( # Find related tables related_tables = find_related_tables(table_name, all_table_names) - related_tables_str = ( - ", ".join(related_tables) if related_tables else "None identified yet" - ) + related_tables_str = ", ".join(related_tables) if related_tables else "None identified yet" # Suggest primary key pattern table_base = table_name.split("_")[0] if "_" in table_name else table_name @@ -168,7 +164,9 @@ def get_table_prompt( if fk_suggestions: fk_suggestions_str = "Consider these foreign key relationships:\n" for i, fk in enumerate(fk_suggestions[:5]): # Limit to 5 suggestions - fk_suggestions_str += f"{i+1}. {fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}\n" + fk_suggestions_str += ( + f"{i+1}. {fk['column']} -> {fk['referenced_table']}.{fk['referenced_column']}\n" + ) # Include examples of related tables that have been processed related_examples = "" @@ -397,9 +395,7 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: context = f"The '{table_name}' table appears to be " # Check if this is a junction/linking table - if "_" in table_name and not any( - p in table_name for p in relationship_patterns.keys() - ): + if "_" in table_name and not any(p in table_name for p in relationship_patterns.keys()): parts = table_name.split("_") if len(parts) == 2 and all(len(p) > 2 for p in parts): return f"This appears to be a junction table linking '{parts[0]}' and '{parts[1]}', likely with a many-to-many relationship." @@ -420,14 +416,14 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: # Add related tables info if related_tables: - context += f"It appears to be related to the following tables: {', '.join(related_tables)}. " + context += ( + f"It appears to be related to the following tables: {', '.join(related_tables)}. " + ) # Guess if it's a child table for related in related_tables: if related in table_name and len(related) < len(table_name): - context += ( - f"It may be a child or detail table for the {related} table. " - ) + context += f"It may be a child or detail table for the {related} table. " break return context @@ -501,9 +497,7 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any for col_name, col_data in table_data["columns"].items(): required_col_attrs = ["description", "type", "null"] if not all(attr in col_data for attr in required_col_attrs): - print( - f"Warning: Column {col_name} is missing required attributes" - ) + print(f"Warning: Column {col_name} is missing required attributes") return {table_name: table_data} else: @@ -513,9 +507,7 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any else: # Try to get the first key if table_name is not found first_key = next(iter(parsed)) - print( - f"Warning: Table name mismatch. Expected {table_name}, got {first_key}" - ) + print(f"Warning: Table name mismatch. Expected {table_name}, got {first_key}") return {table_name: parsed[first_key]} except Exception as e: print(f"Error parsing LLM response for {table_name}: {e}") @@ -567,9 +559,7 @@ def process_table( def main(): # Load the initial schema with table names - initial_schema_path = ( - "examples/crm_tables.json" # Replace with your actual file path - ) + initial_schema_path = "examples/crm_tables.json" # Replace with your actual file path initial_schema = load_initial_schema(initial_schema_path) # Get the list of tables to process @@ -685,9 +675,7 @@ def validate_schema(schema: Dict[str, Any]) -> None: 1 for t in schema["tables"].values() if isinstance(t, dict) and "indexes" in t ) tables_with_foreign_keys = sum( - 1 - for t in schema["tables"].values() - if isinstance(t, dict) and "foreign_keys" in t + 1 for t in schema["tables"].values() if isinstance(t, dict) and "foreign_keys" in t ) print(f"Total tables: {table_count}") @@ -734,18 +722,11 @@ def validate_schema(schema: Dict[str, Any]) -> None: ref_column = fk_data.get("referenced_column") if ref_table and ref_table not in schema["tables"]: - invalid_fks.append( - f"{table_name}.{fk_name} -> {ref_table} (table not found)" - ) + invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (table not found)") elif ref_table and ref_column: ref_table_data = schema["tables"].get(ref_table, {}) - if ( - not isinstance(ref_table_data, dict) - or "columns" not in ref_table_data - ): - invalid_fks.append( - f"{table_name}.{fk_name} -> {ref_table} (no columns)" - ) + if not isinstance(ref_table_data, dict) or "columns" not in ref_table_data: + invalid_fks.append(f"{table_name}.{fk_name} -> {ref_table} (no columns)") elif ref_column not in ref_table_data.get("columns", {}): invalid_fks.append( f"{table_name}.{fk_name} -> {ref_table}.{ref_column} (column not found)" diff --git a/api/index.py b/api/index.py index 2fd62936..a13f685a 100644 --- a/api/index.py +++ b/api/index.py @@ -9,8 +9,7 @@ from functools import wraps from dotenv import load_dotenv -from flask import (Blueprint, Flask, Response, jsonify, render_template, - request, stream_with_context) +from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context from api.agents import AnalysisAgent, RelevancyAgent from api.constants import BENCHMARK, EXAMPLES @@ -22,9 +21,7 @@ # Load environment variables from .env file load_dotenv() -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") # Use the same delimiter as in the JavaScript MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" @@ -126,9 +123,7 @@ def load(): success, result = JSONLoader.load(graph_id, data) # ✅ Handle XML Payload - elif content_type.startswith("application/xml") or content_type.startswith( - "text/xml" - ): + elif content_type.startswith("application/xml") or content_type.startswith("text/xml"): xml_data = request.data graph_id = "" success, result = ODataLoader.load(graph_id, xml_data) @@ -203,9 +198,7 @@ def generate(): step = {"type": "reasoning_step", "message": "Step 1: Analyzing the user query"} yield json.dumps(step) + MESSAGE_DELIMITER - db_description = get_db_description( - graph_id - ) # Ensure the database description is loaded + db_description = get_db_description(graph_id) # Ensure the database description is loaded logging.info(f"Calling to relvancy agent with query: {queries_history[-1]}") answer_rel = agent_rel.get_answer(queries_history[-1], db_description) @@ -219,9 +212,7 @@ def generate(): else: # Use a thread pool to enforce timeout with ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit( - find, graph_id, queries_history, db_description - ) + future = executor.submit(find, graph_id, queries_history, db_description) try: success, result, _ = future.result(timeout=120) except FuturesTimeoutError: @@ -281,9 +272,7 @@ def suggestions(): if graph_id in EXAMPLES: graph_examples = EXAMPLES[graph_id] # Return up to 3 examples, or all if less than 3 - suggestion_questions = random.sample( - graph_examples, min(3, len(graph_examples)) - ) + suggestion_questions = random.sample(graph_examples, min(3, len(graph_examples))) return jsonify(suggestion_questions) else: # If graph doesn't exist in EXAMPLES, return empty list diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index 3a7569e1..c0eea431 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -68,9 +68,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: rel_table = defaultdict(lambda: {"primary_key_table": "", "fk_tables": []}) relationships = {} # First pass: Organize data into tables - for idx, row in tqdm.tqdm( - df.iterrows(), total=len(df), desc="Organizing data" - ): + for idx, row in tqdm.tqdm(df.iterrows(), total=len(df), desc="Organizing data"): schema = row["Schema"] domain = row["Domain"] @@ -87,13 +85,9 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # Add column information field = row["Field"] field_type = row["Type"] if not pd.isna(row["Type"]) else "STRING" - field_desc = ( - row["Description"] if not pd.isna(row["Description"]) else field - ) + field_desc = row["Description"] if not pd.isna(row["Description"]) else field - nullable = ( - True # Default to nullable since we don't have explicit null info - ) + nullable = True # Default to nullable since we don't have explicit null info if not pd.isna(field): tables[table_name]["col_descriptions"].append(field_desc) tables[table_name]["columns"][field] = { @@ -121,9 +115,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: "source_column": source_field, "target_column": ( df.to_dict("records")[idx + 1]["Array Field"] - if not pd.isna( - df.to_dict("records")[idx + 1]["Array Field"] - ) + if not pd.isna(df.to_dict("records")[idx + 1]["Array Field"]) else "" ), "note": "", @@ -155,9 +147,7 @@ def load(graph_id: str, data) -> Tuple[bool, str]: "extra": "", } if field.endswith("_id"): - if len(tables[table_name]["columns"]) == 1 and field.endswith( - "_id" - ): + if len(tables[table_name]["columns"]) == 1 and field.endswith("_id"): suspected_primary_key = field[:-3] if suspected_primary_key in domain: rel_table[field]["primary_key_table"] = table_name diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index b02918e9..c3cdb651 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -60,9 +60,7 @@ def load_to_graph( {"db_name": db_name, "description": db_des}, ) - for table_name, table_info in tqdm.tqdm( - entities.items(), desc="Creating Graph Table Nodes" - ): + for table_name, table_info in tqdm.tqdm(entities.items(), desc="Creating Graph Table Nodes"): table_desc = table_info["description"] embedding_result = embedding_model.embed(table_desc) fk = json.dumps(table_info.get("foreign_keys", [])) diff --git a/api/loaders/odata_loader.py b/api/loaders/odata_loader.py index 31bfb787..4562dd77 100644 --- a/api/loaders/odata_loader.py +++ b/api/loaders/odata_loader.py @@ -63,25 +63,14 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: ).split(".")[-1] col_des = entity_name if len(prop.findall("edm:Annotation", namespaces)) > 0: - if ( - len( - prop.findall("edm:Annotation", namespaces)[0].get( - "String" - ) + if len(prop.findall("edm:Annotation", namespaces)[0].get("String")) > 0: + col_des = prop.findall("edm:Annotation", namespaces)[0].get( + "String" ) - > 0 - ): - col_des = prop.findall("edm:Annotation", namespaces)[ - 0 - ].get("String") entities[entity_name]["col_descriptions"].append(col_des) - entities[entity_name]["columns"][prop_name][ - "description" - ] = col_des + entities[entity_name]["columns"][prop_name]["description"] = col_des except Exception as e: - print( - f"Error parsing property {prop_name} for entity {entity_name}" - ) + print(f"Error parsing property {prop_name} for entity {entity_name}") continue # = {prop.get("Name"): prop.get("Type") for prop in entity_type.findall("edm:Property", namespaces)} @@ -95,25 +84,19 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: entities[entity_name]["description"] = ( entity_name + " with Primery key: " - + entity_type.find( - "edm:Key/edm:PropertyRef", namespaces - ).attrib["Name"] + + entity_type.find("edm:Key/edm:PropertyRef", namespaces).attrib["Name"] ) except: print(f"Error parsing description for entity {entity_name}") entities[entity_name]["description"] = entity_name - for entity_type in tqdm.tqdm( - entity_types, "Parsing OData schema - relationships" - ): + for entity_type in tqdm.tqdm(entity_types, "Parsing OData schema - relationships"): entity_name = entity_type.attrib["Name"] for rel in entity_type.findall("edm:NavigationProperty", namespaces): rel_name = rel.get("Name") - raw_type = rel.get( - "Type" - ) # e.g., 'Collection(Priority.OData.ABILITYVALUES)' + raw_type = rel.get("Type") # e.g., 'Collection(Priority.OData.ABILITYVALUES)' # Clean 'Collection(...)' wrapper if exists if raw_type.startswith("Collection(") and raw_type.endswith(")"): @@ -129,9 +112,7 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: target_fields = entities.get(target_entity, {})["columns"] # TODO This usage is for demonstration purposes only, it should be replaced with a more robust method - source_col, target_col = guess_relationship_columns( - source_fields, target_fields - ) + source_col, target_col = guess_relationship_columns(source_fields, target_fields) if source_col and target_col: # Store the relationship if rel_name not in relationships: @@ -144,9 +125,7 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: "source_column": source_col, "target_column": target_col, "note": ( - "inferred" - if source_col and target_col - else "implicit/subform" + "inferred" if source_col and target_col else "implicit/subform" ), } ) diff --git a/api/loaders/schema_validator.py b/api/loaders/schema_validator.py index 5b6ef0b5..1f0ae169 100644 --- a/api/loaders/schema_validator.py +++ b/api/loaders/schema_validator.py @@ -46,9 +46,7 @@ def validate_table_schema(schema): # Optional: validate foreign keys if "foreign_keys" in table_data: if not isinstance(table_data["foreign_keys"], dict): - errors.append( - f"Foreign keys for table '{table_name}' must be a dictionary" - ) + errors.append(f"Foreign keys for table '{table_name}' must be a dictionary") else: for fk_name, fk_data in table_data["foreign_keys"].items(): for key in ("column", "referenced_table", "referenced_column"): diff --git a/api/utils.py b/api/utils.py index 26124d60..0c8864d1 100644 --- a/api/utils.py +++ b/api/utils.py @@ -67,9 +67,7 @@ def generate_db_description( return description -def llm_answer_validator( - question: str, answer: str, expected_answer: str = None -) -> str: +def llm_answer_validator(question: str, answer: str, expected_answer: str = None) -> str: prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the Generated Answer (generated sql) addresses the Question based on the Expected Answer. @@ -106,9 +104,7 @@ def llm_answer_validator( return validation_set -def llm_table_validator( - question: str, answer: str, tables: List[str] -) -> Tuple[float, str]: +def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[float, str]: prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the retrived Tables relevant to the question and supports the Generated Answer (generated sql). - The tables are with the following structure: @@ -133,9 +129,7 @@ def llm_table_validator( {"role": "system", "content": "You are a Validator assistant."}, { "role": "user", - "content": prompt.format( - question=question, tables=tables, generated_answer=answer - ), + "content": prompt.format(question=question, tables=tables, generated_answer=answer), }, ], response_format={"type": "json_object"}, From bbc32ff0d1db98ea7a37d1368ac1151d9954278f Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 10:55:33 +0300 Subject: [PATCH 04/58] disable as long as it's a private repo --- .github/workflows/dependency-review.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/dependency-review.yml b/.github/workflows/dependency-review.yml index d19e21b7..dae0e66e 100644 --- a/.github/workflows/dependency-review.yml +++ b/.github/workflows/dependency-review.yml @@ -25,6 +25,7 @@ permissions: jobs: dependency-review: + if: github.repository_visibility == 'public' runs-on: ubuntu-latest steps: - name: 'Checkout repository' From 5a837073795575a9af989d72bc549c698fd33a89 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 10:59:51 +0300 Subject: [PATCH 05/58] fix lint issues --- api/agents.py | 24 ++++++++++++++++-------- api/config.py | 2 -- api/helpers/crm_data_generator.py | 4 +--- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/api/agents.py b/api/agents.py index 711843a6..e9f06b80 100644 --- a/api/agents.py +++ b/api/agents.py @@ -24,7 +24,9 @@ def get_analysis( instructions: str = None, ) -> dict: formatted_schema = self._format_schema(combined_tables) - prompt = self._build_prompt(user_query, formatted_schema, db_description, instructions) + prompt = self._build_prompt( + user_query, formatted_schema, db_description, instructions + ) self.messages.append({"role": "user", "content": prompt}) completion_result = completion( model=Config.COMPLETION_MODEL, @@ -36,13 +38,17 @@ def get_analysis( response = completion_result.choices[0].message.content analysis = _parse_response(response) if isinstance(analysis["ambiguities"], list): - analysis["ambiguities"] = [item.replace("-", " ") for item in analysis["ambiguities"]] + analysis["ambiguities"] = [ + item.replace("-", " ") for item in analysis["ambiguities"] + ] analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) if isinstance(analysis["missing_information"], list): analysis["missing_information"] = [ item.replace("-", " ") for item in analysis["missing_information"] ] - analysis["missing_information"] = "- " + "- ".join(analysis["missing_information"]) + analysis["missing_information"] = "- " + "- ".join( + analysis["missing_information"] + ) self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) return analysis @@ -76,9 +82,9 @@ def _format_schema(self, schema_data: List) -> str: nullable = column.get("nullable", False) key_info = ( - f", PRIMARY KEY" + ", PRIMARY KEY" if col_key == "PRI" - else f", FOREIGN KEY" if col_key == "FK" else "" + else ", FOREIGN KEY" if col_key == "FK" else "" ) column_str = f" - {col_name} ({col_type},{key_info},{col_key},{nullable}): {col_description}" table_str += column_str + "\n" @@ -90,7 +96,9 @@ def _format_schema(self, schema_data: List) -> str: column = fk_info.get("column", "") ref_table = fk_info.get("referenced_table", "") ref_column = fk_info.get("referenced_column", "") - table_str += f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" + table_str += ( + f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" + ) formatted_schema.append(table_str) @@ -166,8 +174,8 @@ def _build_prompt( "explanation": "Detailed explanation why the query can or cannot be translated, mentioning instructions explicitly and referencing conversation history if relevant", "sql_query": "High-level SQL query (you must to applying instructions and use previous answers if the question is a continuation)", "tables_used": ["list", "of", "tables", "used", "in", "the", "query", "with", "the", "relationships", "between", "them"], - "missing_information": ["list", "of", "missing", "information"], - "ambiguities": ["list", "of", "ambiguities"], + "missing_information": ["list", "of", "missing", "information"], + "ambiguities": ["list", "of", "ambiguities"], "confidence": integer between 0 and 100 }} diff --git a/api/config.py b/api/config.py index 6accfb4d..67a94e10 100644 --- a/api/config.py +++ b/api/config.py @@ -3,10 +3,8 @@ """ import dataclasses -import os from typing import Union -import boto3 from litellm import embedding diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index 575b72e6..e255e4fa 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -1,12 +1,10 @@ import json import os import time -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional import requests from litellm import completion -from litellm import utils as litellm_utils -from litellm import validate_environment OUTPUT_FILE = "complete_crm_schema.json" MAX_RETRIES = 3 From fa74cfa7bcbbc24ce4fcd5771967fd8f3d769208 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 11:22:02 +0300 Subject: [PATCH 06/58] clean lint --- api/config.py | 2 ++ api/extensions.py | 2 +- api/loaders/base_loader.py | 3 +++ api/loaders/csv_loader.py | 10 ++++++---- api/loaders/graph_loader.py | 2 ++ api/loaders/json_loader.py | 8 ++++---- api/utils.py | 36 ++++++++++++++++++++++++++++++------ onthology.py | 8 +++++--- 8 files changed, 53 insertions(+), 18 deletions(-) diff --git a/api/config.py b/api/config.py index 67a94e10..386155f2 100644 --- a/api/config.py +++ b/api/config.py @@ -9,6 +9,7 @@ class EmbeddingsModel: + """Embeddings model wrapper for text embedding operations.""" def __init__(self, model_name: str, config: dict = None): self.model_name = model_name @@ -51,6 +52,7 @@ class Config: SCHEMA_PATH = "api/schema_schema.json" EMBEDDING_MODEL_NAME = "azure/text-embedding-ada-002" COMPLETION_MODEL = "azure/gpt-4.1" + VALIDATOR_MODEL = "azure/gpt-4.1" TEMPERATURE = 0 # client = boto3.client('sts') # AWS_PROFILE = os.getenv("aws_profile_name") diff --git a/api/extensions.py b/api/extensions.py index 42602e8f..01dd2e1a 100644 --- a/api/extensions.py +++ b/api/extensions.py @@ -10,6 +10,6 @@ try: db = FalkorDB(host="localhost", port=6379) except Exception as e: - raise Exception(f"Failed to connect to FalkorDB: {e}") + raise ConnectionError(f"Failed to connect to FalkorDB: {e}") from e else: db = FalkorDB.from_url(os.getenv("FALKORDB_URL")) diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 7b46649b..d6418382 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -1,8 +1,11 @@ +"""Base loader module providing abstract base class for data loaders.""" + from abc import ABC from typing import Tuple class BaseLoader(ABC): + """Abstract base class for data loaders.""" @staticmethod def load(_graph_id: str, _data) -> Tuple[bool, str]: diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index c0eea431..05d22659 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -1,17 +1,19 @@ +"""CSV loader module for processing CSV files and generating database schemas.""" + import io from collections import defaultdict -from typing import Dict, List, Tuple +from typing import Tuple -# import pandas as pd +import pandas as pd import tqdm -from litellm import embedding -from api.extensions import db from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph class CSVLoader(BaseLoader): + """CSV data loader for processing CSV files and loading them into graph database.""" + @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: """ diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index c3cdb651..a6edc96e 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -1,3 +1,5 @@ +"""Graph loader module for loading data into graph databases.""" + import json import tqdm diff --git a/api/loaders/json_loader.py b/api/loaders/json_loader.py index 12427828..0b74ec60 100644 --- a/api/loaders/json_loader.py +++ b/api/loaders/json_loader.py @@ -1,16 +1,15 @@ +"""JSON loader module for processing JSON schema files.""" + import json from typing import Tuple import tqdm -from jsonschema import ValidationError, validate -from litellm import embedding +from jsonschema import ValidationError from api.config import Config -from api.extensions import db from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph from api.loaders.schema_validator import validate_table_schema -from api.utils import generate_db_description try: with open(Config.SCHEMA_PATH, "r", encoding="utf-8") as f: @@ -22,6 +21,7 @@ class JSONLoader(BaseLoader): + """JSON schema loader for loading database schemas from JSON files.""" @staticmethod def load(graph_id: str, data) -> Tuple[bool, str]: diff --git a/api/utils.py b/api/utils.py index 0c8864d1..65932588 100644 --- a/api/utils.py +++ b/api/utils.py @@ -1,3 +1,5 @@ +"""Utility functions for the text2sql API.""" + import json from typing import List, Tuple @@ -47,9 +49,9 @@ def generate_db_description( tables_formatted = ", ".join(table_names[:-1]) + f", and {table_names[-1]}" prompt = ( - f"You are a helpful assistant. Generate a concise description of the database named '{db_name}' " - f"which contains the following tables: {tables_formatted}.\n\n" - f"Description:" + f"You are a helpful assistant. Generate a concise description of " + f"the database named '{db_name}' which contains the following tables: " + f"{tables_formatted}.\n\nDescription:" ) response = completion( @@ -68,6 +70,17 @@ def generate_db_description( def llm_answer_validator(question: str, answer: str, expected_answer: str = None) -> str: + """ + Validate an answer using LLM. + + Args: + question: The original question + answer: The generated answer + expected_answer: The expected answer for comparison + + Returns: + JSON string with validation results + """ prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the Generated Answer (generated sql) addresses the Question based on the Expected Answer. @@ -86,7 +99,7 @@ def llm_answer_validator(question: str, answer: str, expected_answer: str = None {{"relevance_score": float, "explanation": "Your assessment here."}} """ response = completion( - model=Config.VALIDTOR_MODEL, + model=Config.VALIDATOR_MODEL, messages=[ {"role": "system", "content": "You are a Validator assistant."}, { @@ -105,6 +118,17 @@ def llm_answer_validator(question: str, answer: str, expected_answer: str = None def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[float, str]: + """ + Validate table relevance using LLM. + + Args: + question: The original question + answer: The generated answer + tables: List of available tables + + Returns: + Tuple of relevance score and explanation + """ prompt = """ You are evaluating an answer generated by a text-to-sql RAG-based system. Assess how well the retrived Tables relevant to the question and supports the Generated Answer (generated sql). - The tables are with the following structure: @@ -124,7 +148,7 @@ def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[ {{"relevance_score": float, "explanation": "Your assessment here."}} """ response = completion( - model=Config.VALIDTOR_MODEL, + model=Config.VALIDATOR_MODEL, messages=[ {"role": "system", "content": "You are a Validator assistant."}, { @@ -139,7 +163,7 @@ def llm_table_validator(question: str, answer: str, tables: List[str]) -> Tuple[ val_res = json.loads(validation_set) score = val_res["relevance_score"] explanation = val_res["explanation"] - except Exception as e: + except (json.JSONDecodeError, KeyError) as e: print(f"Error: {e}") score = 0.0 explanation = "Error: Unable to parse the response." diff --git a/onthology.py b/onthology.py index ecf98e5c..9ae9a783 100644 --- a/onthology.py +++ b/onthology.py @@ -1,9 +1,11 @@ +"""Ontology generation module for CRM system knowledge graph.""" + from falkordb import FalkorDB from graphrag_sdk import Ontology from graphrag_sdk.models.litellm import LiteModel model = LiteModel(model_name="gemini/gemini-2.0-flash") db = FalkorDB(host="localhost", port=6379) -kg_name = "crm_system" -ontology = Ontology.from_kg_graph(db.select_graph(kg_name), 1000000000) -ontology.save_to_graph(db.select_graph(f"{{{kg_name}}}_schema")) +KG_NAME = "crm_system" +ontology = Ontology.from_kg_graph(db.select_graph(KG_NAME), 1000000000) +ontology.save_to_graph(db.select_graph(f"{{{KG_NAME}}}_schema")) From 358f978dd8f06f030ea3bd29b408a1018122deb9 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 12:04:57 +0300 Subject: [PATCH 07/58] fix more lint errors --- api/agents.py | 50 ++++++++--- api/constants.py | 139 ++++++++++++++++++++++++------ api/graph.py | 13 +-- api/helpers/crm_data_generator.py | 90 +++++++++++-------- api/index.py | 55 ++++++------ 5 files changed, 241 insertions(+), 106 deletions(-) diff --git a/api/agents.py b/api/agents.py index e9f06b80..c7914953 100644 --- a/api/agents.py +++ b/api/agents.py @@ -1,3 +1,5 @@ +"""Module containing agent classes for handling analysis and SQL generation tasks.""" + import json from typing import Any, Dict, List @@ -7,7 +9,10 @@ class AnalysisAgent: + """Agent for analyzing user queries and generating database analysis.""" + def __init__(self, queries_history: list, result_history: list): + """Initialize the analysis agent with query and result history.""" if result_history is None: self.messages = [] else: @@ -23,6 +28,7 @@ def get_analysis( db_description: str, instructions: str = None, ) -> dict: + """Get analysis of user query against database schema.""" formatted_schema = self._format_schema(combined_tables) prompt = self._build_prompt( user_query, formatted_schema, db_description, instructions @@ -86,7 +92,8 @@ def _format_schema(self, schema_data: List) -> str: if col_key == "PRI" else ", FOREIGN KEY" if col_key == "FK" else "" ) - column_str = f" - {col_name} ({col_type},{key_info},{col_key},{nullable}): {col_description}" + column_str = (f" - {col_name} ({col_type},{key_info},{col_key}," + f"{nullable}): {col_description}") table_str += column_str + "\n" # Format foreign keys @@ -161,19 +168,27 @@ def _build_prompt( - Analyze the query's translatability into SQL according to the instructions. - Apply the instructions explicitly. - - If you CANNOT apply instructions in the SQL, explain why under "instructions_comments", "explanation" and reduce your confidence. + - If you CANNOT apply instructions in the SQL, explain why under + "instructions_comments", "explanation" and reduce your confidence. - Penalize confidence appropriately if any part of the instructions is unmet. - - When there several tables that can be used to answer the question, you can combine them in a single SQL query. + - When there several tables that can be used to answer the question, + you can combine them in a single SQL query. Provide your output ONLY in the following JSON structure: ```json {{ "is_sql_translatable": true or false, - "instructions_comments": "Comments about any part of the instructions, especially if they are unclear, impossible, or partially met", - "explanation": "Detailed explanation why the query can or cannot be translated, mentioning instructions explicitly and referencing conversation history if relevant", - "sql_query": "High-level SQL query (you must to applying instructions and use previous answers if the question is a continuation)", - "tables_used": ["list", "of", "tables", "used", "in", "the", "query", "with", "the", "relationships", "between", "them"], + "instructions_comments": ("Comments about any part of the instructions, " + "especially if they are unclear, impossible, " + "or partially met"), + "explanation": ("Detailed explanation why the query can or cannot be " + "translated, mentioning instructions explicitly and " + "referencing conversation history if relevant"), + "sql_query": ("High-level SQL query (you must to applying instructions " + "and use previous answers if the question is a continuation)"), + "tables_used": ["list", "of", "tables", "used", "in", "the", "query", + "with", "the", "relationships", "between", "them"], "missing_information": ["list", "of", "missing", "information"], "ambiguities": ["list", "of", "ambiguities"], "confidence": integer between 0 and 100 @@ -189,14 +204,18 @@ def _build_prompt( 6. Consider if complex calculations are feasible in SQL. 7. Identify multiple interpretations if they exist. 8. Strictly apply instructions; explain and penalize if not possible. - 9. If the question is a follow-up, resolve references using the conversation history and previous answers. + 9. If the question is a follow-up, resolve references using the + conversation history and previous answers. Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block. """ return prompt class RelevancyAgent: + """Agent for determining relevancy of queries to database schema.""" + def __init__(self, queries_history: list, result_history: list): + """Initialize the relevancy agent with query and result history.""" if result_history is None: self.messages = [] else: @@ -206,6 +225,7 @@ def __init__(self, queries_history: list, result_history: list): self.messages.append({"role": "assistant", "content": result}) def get_answer(self, user_question: str, database_desc: dict) -> dict: + """Get relevancy assessment for user question against database description.""" self.messages.append( { "role": "user", @@ -275,12 +295,15 @@ def get_answer(self, user_question: str, database_desc: dict) -> dict: class FollowUpAgent: + """Agent for handling follow-up questions and conversational context.""" + def __init__(self): - pass + """Initialize the follow-up agent.""" def get_answer( self, user_question: str, conversation_hist: list, database_schema: dict ) -> dict: + """Get answer for follow-up questions using conversation history.""" completion_result = completion( model=Config.COMPLETION_MODEL, messages=[ @@ -333,14 +356,19 @@ def get_answer( }} -4. Ensure your response is concise, polite, and helpful. When asking clarifying questions, be specific and guide the user toward providing the missing details so you can effectively address their query.""" +4. Ensure your response is concise, polite, and helpful. When asking clarifying + questions, be specific and guide the user toward providing the missing details + so you can effectively address their query.""" class TaxonomyAgent: + """Agent for taxonomy classification of questions and SQL queries.""" + def __init__(self): - pass + """Initialize the taxonomy agent.""" def get_answer(self, question: str, sql: str) -> str: + """Get taxonomy classification for a question and SQL pair.""" messages = [ { "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql), diff --git a/api/constants.py b/api/constants.py index 35357011..93532b47 100644 --- a/api/constants.py +++ b/api/constants.py @@ -1,13 +1,20 @@ +"""Constants and benchmark data for the text2sql application.""" + EXAMPLES = { "crm_usecase": [ - "Which companies have generated the most revenue through closed deals, and how much revenue did they generate?", + ("Which companies have generated the most revenue through closed deals, " + "and how much revenue did they generate?"), "How many leads converted into deals over the last month", - "Which companies have open sales opportunities and active SLA agreements in place?", - "Which high-value sales opportunities (value > $50,000) have upcoming meetings scheduled, and what companies are they associated with?", + ("Which companies have open sales opportunities and active SLA agreements " + "in place?"), + ("Which high-value sales opportunities (value > $50,000) have upcoming meetings " + "scheduled, and what companies are they associated with?"), ], "ERP_system": [ - # "What is the total value of all purchase orders created in the last quarter?", - # "Which suppliers have the highest number of active purchase orders, and what is the total value of those orders?", + # ("What is the total value of all purchase orders created in the last " + # "quarter?"), + # ("Which suppliers have the highest number of active purchase orders, " + # "and what is the total value of those orders?"), "What is the total order value for customer Almo Office?", "Show the total amount of all orders placed on 11/24", "What's the profit for order SO2400002?", @@ -43,43 +50,125 @@ BENCHMARK = [ { - "question": "List all contacts who are associated with companies that have at least one active deal in the pipeline, and include the deal stage.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, d.deal_name, ds.stage_name FROM contacts AS c JOIN company_contacts AS cc ON c.contact_id = cc.contact_id JOIN companies AS co ON cc.company_id = co.company_id JOIN deals AS d ON co.company_id = d.company_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.is_active = 1;", + "question": ("List all contacts who are associated with companies that have at " + "least one active deal in the pipeline, and include the deal stage."), + "sql": ("SELECT DISTINCT c.contact_id, c.first_name, c.last_name, d.deal_id, " + "d.deal_name, ds.stage_name FROM contacts AS c " + "JOIN company_contacts AS cc ON c.contact_id = cc.contact_id " + "JOIN companies AS co ON cc.company_id = co.company_id " + "JOIN deals AS d ON co.company_id = d.company_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.is_active = 1;"), }, { - "question": "Which sales representatives (users) have closed deals worth more than $100,000 in the past year, and what was the total value of deals they closed?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS total_closed_value FROM users AS u JOIN deals AS d ON u.user_id = d.owner_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' AND d.close_date >= DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id HAVING total_closed_value > 100000;", + "question": ("Which sales representatives (users) have closed deals worth more " + "than $100,000 in the past year, and what was the total value of " + "deals they closed?"), + "sql": ("SELECT u.user_id, u.first_name, u.last_name, SUM(d.amount) AS " + "total_closed_value FROM users AS u " + "JOIN deals AS d ON u.user_id = d.owner_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' AND d.close_date >= " + "DATE_SUB(CURDATE(), INTERVAL 1 YEAR) GROUP BY u.user_id " + "HAVING total_closed_value > 100000;"), }, { - "question": "Find all contacts who attended at least one event and were later converted into leads that became opportunities within three months of the event.", - "sql": "SELECT DISTINCT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN event_attendees AS ea ON c.contact_id = ea.contact_id JOIN events AS e ON ea.event_id = e.event_id JOIN leads AS l ON c.contact_id = l.contact_id JOIN opportunities AS o ON l.lead_id = o.lead_id WHERE o.created_date BETWEEN e.event_date AND DATE_ADD(e.event_date, INTERVAL 3 MONTH);", + "question": ("Find all contacts who attended at least one event and were later " + "converted into leads that became opportunities within three months " + "of the event."), + "sql": ("SELECT DISTINCT c.contact_id, c.first_name, c.last_name " + "FROM contacts AS c " + "JOIN event_attendees AS ea ON c.contact_id = ea.contact_id " + "JOIN events AS e ON ea.event_id = e.event_id " + "JOIN leads AS l ON c.contact_id = l.contact_id " + "JOIN opportunities AS o ON l.lead_id = o.lead_id " + "WHERE o.created_date BETWEEN e.event_date AND " + "DATE_ADD(e.event_date, INTERVAL 3 MONTH);"), }, { - "question": "Which customers have the highest lifetime value based on their total invoice payments, including refunds and discounts?", - "sql": "SELECT c.contact_id, c.first_name, c.last_name, SUM(i.total_amount - COALESCE(r.refund_amount, 0) - COALESCE(d.discount_amount, 0)) AS lifetime_value FROM contacts AS c JOIN orders AS o ON c.contact_id = o.contact_id JOIN invoices AS i ON o.order_id = i.order_id LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;", + "question": ("Which customers have the highest lifetime value based on their " + "total invoice payments, including refunds and discounts?"), + "sql": ("SELECT c.contact_id, c.first_name, c.last_name, " + "SUM(i.total_amount - COALESCE(r.refund_amount, 0) - " + "COALESCE(d.discount_amount, 0)) AS lifetime_value " + "FROM contacts AS c " + "JOIN orders AS o ON c.contact_id = o.contact_id " + "JOIN invoices AS i ON o.order_id = i.order_id " + "LEFT JOIN refunds AS r ON i.invoice_id = r.invoice_id " + "LEFT JOIN discounts AS d ON i.invoice_id = d.invoice_id " + "GROUP BY c.contact_id ORDER BY lifetime_value DESC LIMIT 10;"), }, { - "question": "Show all deals that have involved at least one email exchange, one meeting, and one phone call with a contact in the past six months.", - "sql": "SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d JOIN contacts AS c ON d.contact_id = c.contact_id JOIN emails AS e ON c.contact_id = e.contact_id JOIN meetings AS m ON c.contact_id = m.contact_id JOIN phone_calls AS p ON c.contact_id = p.contact_id WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);", + "question": ("Show all deals that have involved at least one email exchange, " + "one meeting, and one phone call with a contact in the past six months."), + "sql": ("SELECT DISTINCT d.deal_id, d.deal_name FROM deals AS d " + "JOIN contacts AS c ON d.contact_id = c.contact_id " + "JOIN emails AS e ON c.contact_id = e.contact_id " + "JOIN meetings AS m ON c.contact_id = m.contact_id " + "JOIN phone_calls AS p ON c.contact_id = p.contact_id " + "WHERE e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) " + "AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH) " + "AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 6 MONTH);"), }, { - "question": "Which companies have the highest number of active support tickets, and how does their number of tickets correlate with their total deal value?", - "sql": "SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, SUM(d.amount) AS total_deal_value FROM companies AS co LEFT JOIN support_tickets AS st ON co.company_id = st.company_id AND st.status = 'Open' LEFT JOIN deals AS d ON co.company_id = d.company_id GROUP BY co.company_id ORDER BY active_tickets DESC;", + "question": ("Which companies have the highest number of active support tickets, " + "and how does their number of tickets correlate with their total deal value?"), + "sql": ("SELECT co.company_id, co.company_name, COUNT(st.ticket_id) AS active_tickets, " + "SUM(d.amount) AS total_deal_value FROM companies AS co " + "LEFT JOIN support_tickets AS st ON co.company_id = st.company_id " + "AND st.status = 'Open' " + "LEFT JOIN deals AS d ON co.company_id = d.company_id " + "GROUP BY co.company_id ORDER BY active_tickets DESC;"), }, { - "question": "Retrieve all contacts who are assigned to a sales rep but have not been contacted via email, phone, or meeting in the past three months.", - "sql": "SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c JOIN users AS u ON c.owner_id = u.user_id LEFT JOIN emails AS e ON c.contact_id = e.contact_id AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) LEFT JOIN meetings AS m ON c.contact_id = m.contact_id AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) WHERE e.contact_id IS NULL AND p.contact_id IS NULL AND m.contact_id IS NULL;", + "question": ("Retrieve all contacts who are assigned to a sales rep but have not " + "been contacted via email, phone, or meeting in the past three months."), + "sql": ("SELECT c.contact_id, c.first_name, c.last_name FROM contacts AS c " + "JOIN users AS u ON c.owner_id = u.user_id " + "LEFT JOIN emails AS e ON c.contact_id = e.contact_id " + "AND e.sent_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "LEFT JOIN phone_calls AS p ON c.contact_id = p.contact_id " + "AND p.call_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "LEFT JOIN meetings AS m ON c.contact_id = m.contact_id " + "AND m.meeting_date >= DATE_SUB(CURDATE(), INTERVAL 3 MONTH) " + "WHERE e.contact_id IS NULL AND p.contact_id IS NULL " + "AND m.contact_id IS NULL;"), }, { - "question": "Which email campaigns resulted in the highest number of closed deals, and what was the average deal size for those campaigns?", - "sql": "SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec JOIN contacts AS c ON ec.campaign_id = c.campaign_id JOIN deals AS d ON c.contact_id = d.contact_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id ORDER BY closed_deals DESC;", + "question": ("Which email campaigns resulted in the highest number of closed deals, " + "and what was the average deal size for those campaigns?"), + "sql": ("SELECT ec.campaign_id, ec.campaign_name, COUNT(d.deal_id) AS closed_deals, " + "AVG(d.amount) AS avg_deal_value FROM email_campaigns AS ec " + "JOIN contacts AS c ON ec.campaign_id = c.campaign_id " + "JOIN deals AS d ON c.contact_id = d.contact_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' GROUP BY ec.campaign_id " + "ORDER BY closed_deals DESC;"), }, { - "question": "Find the average time it takes for a lead to go from creation to conversion into a deal, broken down by industry.", - "sql": "SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) AS avg_conversion_time FROM leads AS l JOIN companies AS co ON l.company_id = co.company_id JOIN industries AS ind ON co.industry_id = ind.industry_id JOIN opportunities AS o ON l.lead_id = o.lead_id JOIN deals AS d ON o.opportunity_id = d.opportunity_id WHERE d.stage_id IN (SELECT stage_id FROM deal_stages WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name ORDER BY avg_conversion_time ASC;", + "question": ("Find the average time it takes for a lead to go from creation to " + "conversion into a deal, broken down by industry."), + "sql": ("SELECT ind.industry_name, AVG(DATEDIFF(d.close_date, l.created_date)) " + "AS avg_conversion_time FROM leads AS l " + "JOIN companies AS co ON l.company_id = co.company_id " + "JOIN industries AS ind ON co.industry_id = ind.industry_id " + "JOIN opportunities AS o ON l.lead_id = o.lead_id " + "JOIN deals AS d ON o.opportunity_id = d.opportunity_id " + "WHERE d.stage_id IN (SELECT stage_id FROM deal_stages " + "WHERE stage_name = 'Closed Won') GROUP BY ind.industry_name " + "ORDER BY avg_conversion_time ASC;"), }, { - "question": "Which sales reps (users) have the highest win rate, calculated as the percentage of their assigned leads that convert into closed deals?", - "sql": "SELECT u.user_id, u.first_name, u.last_name, COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate FROM users AS u JOIN leads AS l ON u.user_id = l.owner_id LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id JOIN deal_stages AS ds ON d.stage_id = ds.stage_id WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id ORDER BY win_rate DESC;", + "question": ("Which sales reps (users) have the highest win rate, calculated as " + "the percentage of their assigned leads that convert into closed deals?"), + "sql": ("SELECT u.user_id, u.first_name, u.last_name, " + "COUNT(DISTINCT d.deal_id) / COUNT(DISTINCT l.lead_id) * 100 AS win_rate " + "FROM users AS u " + "JOIN leads AS l ON u.user_id = l.owner_id " + "LEFT JOIN opportunities AS o ON l.lead_id = o.lead_id " + "LEFT JOIN deals AS d ON o.opportunity_id = d.opportunity_id " + "JOIN deal_stages AS ds ON d.stage_id = ds.stage_id " + "WHERE ds.stage_name = 'Closed Won' GROUP BY u.user_id " + "ORDER BY win_rate DESC;"), }, ] diff --git a/api/graph.py b/api/graph.py index d7986787..1fa783aa 100644 --- a/api/graph.py +++ b/api/graph.py @@ -61,7 +61,8 @@ def find( previous_queries = queries_history[:-1] logging.info( - f"Calling to an LLM to find relevant tables and columns for the query: {user_query}" + "Calling to an LLM to find relevant tables and columns for the query: %s", + user_query ) # Call the completion model to get the relevant Cypher queries to retrieve # from the Graph that represent the Database schema. @@ -92,16 +93,16 @@ def find( # Parse JSON string and convert to Pydantic model json_data = json.loads(json_str) descriptions = Descriptions(**json_data) - logging.info(f"Find tables based on: {descriptions.tables_descriptions}") + logging.info("Find tables based on: %s", descriptions.tables_descriptions) tables_des = _find_tables(graph, descriptions.tables_descriptions) - logging.info(f"Find tables based on columns: {descriptions.columns_descriptions}") + logging.info("Find tables based on columns: %s", descriptions.columns_descriptions) tables_by_columns_des = _find_tables_by_columns(graph, descriptions.columns_descriptions) # table names for sphere and route extraction base_tables_names = [table[0] for table in tables_des] logging.info("Extracting tables by sphere") tables_by_sphere = _find_tables_sphere(graph, base_tables_names) - logging.info(f"Extracting tables by connecting routes {base_tables_names}") + logging.info("Extracting tables by connecting routes %s", base_tables_names) tables_by_route, _ = find_connecting_tables(graph, base_tables_names) combined_tables = _get_unique_tables( tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere @@ -225,8 +226,8 @@ def _get_unique_tables(tables_list): table_info[3] = [dict(od) for od in table_info[3]] table_info[2] = "Foreign keys: " + table_info[2] unique_tables[table_name] = table_info - except: - print(f"Error: {table_info}") + except Exception as e: + print(f"Error: {table_info}, Exception: {e}") # Return the values (the unique table info lists) return list(unique_tables.values()) diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index e255e4fa..9330c2a9 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -1,3 +1,10 @@ +""" +CRM data generator module for creating complete database schemas with relationships. + +This module provides functionality to generate comprehensive CRM database schemas +with proper primary/foreign key relationships and table structures. +""" + import json import os import time @@ -22,7 +29,7 @@ def load_initial_schema(file_path: str) -> Dict[str, Any]: """Load the initial schema file with table names""" try: - with open(file_path, "r") as file: + with open(file_path, "r", encoding="utf-8") as file: schema = json.load(file) print(f"Loaded initial schema with {len(schema.get('tables', {}))} tables") return schema @@ -46,7 +53,7 @@ def save_schema(schema: Dict[str, Any], output_file: str = OUTPUT_FILE) -> None: "table_relationships": {k: list(v) for k, v in key_registry["table_relationships"].items()}, } - with open(output_file, "w") as file: + with open(output_file, "w", encoding="utf-8") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {output_file}") @@ -71,7 +78,7 @@ def update_key_registry(table_name: str, table_data: Dict[str, Any]) -> None: if table_name not in key_registry["table_relationships"]: key_registry["table_relationships"][table_name] = set() - for fk_name, fk_data in table_data["foreign_keys"].items(): + for fk_data in table_data["foreign_keys"].values(): column = fk_data.get("column") ref_table = fk_data.get("referenced_table") ref_column = fk_data.get("referenced_column") @@ -133,18 +140,18 @@ def get_table_prompt( related_tables = find_related_tables(table_name, all_table_names) related_tables_str = ", ".join(related_tables) if related_tables else "None identified yet" - # Suggest primary key pattern - table_base = table_name.split("_")[0] if "_" in table_name else table_name - suggested_pk = f"{table_name}_id" # Default pattern + # # Suggest primary key pattern + # table_base = table_name.split("_")[0] if "_" in table_name else table_name + # suggested_pk = f"{table_name}_id" # Default pattern - # Check if related tables have primary keys to follow same pattern - for related in related_tables: - if related in key_registry["primary_keys"]: - related_pk = key_registry["primary_keys"][related] - if related_pk.endswith("_id") and related in related_pk: - # Follow the same pattern - suggested_pk = f"{table_name}_id" - break + # # Check if related tables have primary keys to follow same pattern + # for related in related_tables: + # if related in key_registry["primary_keys"]: + # related_pk = key_registry["primary_keys"][related] + # if related_pk.endswith("_id") and related in related_pk: + # # Follow the same pattern + # suggested_pk = f"{table_name}_id" + # break # Prepare foreign key suggestions fk_suggestions = [] @@ -183,7 +190,8 @@ def get_table_prompt( contacts_example = """ { "contacts": { - "description": "Stores information about individual contacts within the CRM system, including personal details and relationship to companies.", + "description": ("Stores information about individual contacts within the CRM " + "system, including personal details and relationship to companies."), "columns": { "contact_id": { "description": "Unique identifier for each contact", @@ -283,7 +291,8 @@ def get_table_prompt( table_context = get_table_context(table_name, related_tables) keys = json.dumps(topology["tables"][table_name]) prompt = f""" -You are an expert database architect specializing in CRM systems. Create a detailed JSON schema for the '{table_name}' table in our CRM database. +You are an expert database architect specializing in CRM systems. Create a detailed +JSON schema for the '{table_name}' table in our CRM database. CONTEXT ABOUT THIS TABLE: {table_context} @@ -331,7 +340,8 @@ def get_table_prompt( - For many-to-many relationships, create appropriate junction tables - Ensure referential integrity with foreign key constraints -Return ONLY valid JSON for the '{table_name}' table structure without any explanation or additional text: +Return ONLY valid JSON for the '{table_name}' table structure without any +explanation or additional text: {{ "{table_name}": {{ "description": "...", @@ -393,10 +403,11 @@ def get_table_context(table_name: str, related_tables: List[str]) -> str: context = f"The '{table_name}' table appears to be " # Check if this is a junction/linking table - if "_" in table_name and not any(p in table_name for p in relationship_patterns.keys()): + if "_" in table_name and not any(p in table_name for p in relationship_patterns): parts = table_name.split("_") if len(parts) == 2 and all(len(p) > 2 for p in parts): - return f"This appears to be a junction table linking '{parts[0]}' and '{parts[1]}', likely with a many-to-many relationship." + return (f"This appears to be a junction table linking '{parts[0]}' and " + f"'{parts[1]}', likely with a many-to-many relationship.") # Check for main entities for entity, description in entities.items(): @@ -449,8 +460,8 @@ def call_llm_api(prompt: str, retries: int = MAX_RETRIES) -> Optional[str]: ) if result: return result - else: - print(f"Empty response from API (attempt {attempt}/{retries})") + + print(f"Empty response from API (attempt {attempt}/{retries})") except requests.exceptions.RequestException as e: print(f"API request error (attempt {attempt}/{retries}): {e}") @@ -479,7 +490,7 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any # Cleanup any trailing/leading text start_idx = response.find("{") end_idx = response.rfind("}") + 1 - if start_idx >= 0 and end_idx > start_idx: + if 0 <= start_idx < end_idx: response = response[start_idx:end_idx] parsed = json.loads(response) @@ -498,15 +509,15 @@ def parse_llm_response(response: str, table_name: str) -> Optional[Dict[str, Any print(f"Warning: Column {col_name} is missing required attributes") return {table_name: table_data} - else: - missing = [key for key in required_keys if key not in table_data] - print(f"Warning: Table schema missing required sections: {missing}") - return {table_name: table_data} # Return anyway, but with warning - else: - # Try to get the first key if table_name is not found - first_key = next(iter(parsed)) - print(f"Warning: Table name mismatch. Expected {table_name}, got {first_key}") - return {table_name: parsed[first_key]} + + missing = [key for key in required_keys if key not in table_data] + print(f"Warning: Table schema missing required sections: {missing}") + return {table_name: table_data} # Return anyway, but with warning + + # Try to get the first key if table_name is not found + first_key = next(iter(parsed)) + print(f"Warning: Table name mismatch. Expected {table_name}, got {first_key}") + return {table_name: parsed[first_key]} except Exception as e: print(f"Error parsing LLM response for {table_name}: {e}") print(f"Raw response: {response[:500]}...") # Show first 500 chars @@ -556,6 +567,7 @@ def process_table( def main(): + """Main function to generate complete CRM schema with relationships.""" # Load the initial schema with table names initial_schema_path = "examples/crm_tables.json" # Replace with your actual file path initial_schema = load_initial_schema(initial_schema_path) @@ -572,7 +584,7 @@ def main(): # If we have existing work, load it if os.path.exists(OUTPUT_FILE): try: - with open(OUTPUT_FILE, "r") as file: + with open(OUTPUT_FILE, "r", encoding="utf-8") as file: schema = json.load(file) print(f"Loaded existing schema from {OUTPUT_FILE}") except Exception as e: @@ -595,7 +607,8 @@ def table_priority(table_name): # Process tables for i, table_name in enumerate(tables): print( - f"\nProcessing table {i+1}/{len(tables)}: {table_name} (Priority: {table_priority(table_name)})" + f"\nProcessing table {i+1}/{len(tables)}: {table_name} " + f"(Priority: {table_priority(table_name)})" ) schema = process_table(table_name, schema, all_table_names, topology) @@ -616,13 +629,18 @@ def table_priority(table_name): def generate_keys(tables) -> Dict[str, Any]: + """Generate primary and foreign keys for CRM tables.""" path = "examples/crm_topology.json" + last_key = 0 # Initialize default value + schema = {"tables": {}} # Initialize default schema + # If we have existing work, load it if os.path.exists(path): try: - with open(path, "r") as file: + with open(path, "r", encoding="utf-8") as file: schema = json.load(file) - last_key = tables.index(list(schema["tables"].keys())[-1]) + if schema.get("tables"): + last_key = tables.index(list(schema["tables"].keys())[-1]) print(f"Loaded existing schema from {path}") except Exception as e: print(f"Error loading existing schema: {e}") @@ -652,7 +670,7 @@ def generate_keys(tables) -> Dict[str, Any]: new_table = json.loads(response) schema["tables"].update(new_table) - with open(path, "w") as file: + with open(path, "w", encoding="utf-8") as file: json.dump(schema, file, indent=2) print(f"Schema saved to {path}") print(f"Final schema saved to {path}") diff --git a/api/index.py b/api/index.py index a13f685a..d8bbbffe 100644 --- a/api/index.py +++ b/api/index.py @@ -12,7 +12,7 @@ from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context from api.agents import AnalysisAgent, RelevancyAgent -from api.constants import BENCHMARK, EXAMPLES +from api.constants import EXAMPLES from api.extensions import db from api.graph import find, get_db_description from api.loaders.csv_loader import CSVLoader @@ -34,7 +34,7 @@ def verify_token(token): """Verify the token provided in the request""" - return token == SECRET_TOKEN or token == SECRET_TOKEN_ERP or token == "null" + return token in (SECRET_TOKEN, SECRET_TOKEN_ERP, 'null') def token_required(f): @@ -77,26 +77,24 @@ def graphs(): """ This route is used to list all the graphs that are available in the database. """ - graphs = db.list_graphs() + user_graphs = db.list_graphs() if os.getenv("USER_TOKEN") == SECRET_TOKEN: - if "hospital" in graphs: + if "hospital" in user_graphs: return ["hospital"] - else: - return [] + return [] if os.getenv("USER_TOKEN") == SECRET_TOKEN_ERP: - if "ERP_system" in graphs: + if "ERP_system" in user_graphs: return ["ERP_system"] - else: - return ["crm_usecase"] - elif os.getenv("USER_TOKEN") == "null": - if "crm_usecase" in graphs: + return ["crm_usecase"] + + if os.getenv("USER_TOKEN") == "null": + if "crm_usecase" in user_graphs: return ["crm_usecase"] - else: - return [] - else: - graphs.remove("hospital") - return graphs + return [] + + user_graphs.remove("hospital") + return user_graphs @app.route("/graphs", methods=["POST"]) @@ -189,7 +187,7 @@ def query(graph_id: str): if not queries_history: return jsonify({"error": "Invalid or missing JSON data"}), 400 - logging.info(f"User Query: {queries_history[-1]}") + logging.info("User Query: %s", queries_history[-1]) # Create a generator function for streaming def generate(): @@ -200,31 +198,32 @@ def generate(): yield json.dumps(step) + MESSAGE_DELIMITER db_description = get_db_description(graph_id) # Ensure the database description is loaded - logging.info(f"Calling to relvancy agent with query: {queries_history[-1]}") + logging.info("Calling to relvancy agent with query: %s", queries_history[-1]) answer_rel = agent_rel.get_answer(queries_history[-1], db_description) if answer_rel["status"] != "On-topic": step = { "type": "followup_questions", "message": "Off topic question: " + answer_rel["reason"], } - logging.info(f"SQL Fail reason: {answer_rel["reason"]}") + logging.info("SQL Fail reason: %s", answer_rel["reason"]) yield json.dumps(step) + MESSAGE_DELIMITER else: # Use a thread pool to enforce timeout with ThreadPoolExecutor(max_workers=1) as executor: future = executor.submit(find, graph_id, queries_history, db_description) try: - success, result, _ = future.result(timeout=120) + _, result, _ = future.result(timeout=120) except FuturesTimeoutError: yield json.dumps( { "type": "error", - "message": "Timeout error while finding tables relevant to your request.", + "message": ("Timeout error while finding tables relevant to " + "your request."), } ) + MESSAGE_DELIMITER return except Exception as e: - logging.info(f"Error in find function: {e}") + logging.info("Error in find function: %s", e) yield json.dumps( {"type": "error", "message": "Error in find function"} ) + MESSAGE_DELIMITER @@ -232,12 +231,12 @@ def generate(): step = {"type": "reasoning_step", "message": "Step 2: Generating SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER - logging.info(f"Calling to analysis agent with query: {queries_history[-1]}") + logging.info("Calling to analysis agent with query: %s", queries_history[-1]) answer_an = agent_an.get_analysis( queries_history[-1], result, db_description, instructions ) - logging.info(f"SQL Result: {answer_an['sql_query']}") + logging.info("SQL Result: %s", answer_an['sql_query']) yield json.dumps( { "type": "final_result", @@ -274,12 +273,12 @@ def suggestions(): # Return up to 3 examples, or all if less than 3 suggestion_questions = random.sample(graph_examples, min(3, len(graph_examples))) return jsonify(suggestion_questions) - else: - # If graph doesn't exist in EXAMPLES, return empty list - return jsonify([]) + + # If graph doesn't exist in EXAMPLES, return empty list + return jsonify([]) except Exception as e: - logging.error(f"Error fetching suggestions: {e}") + logging.error("Error fetching suggestions: %s", e) return jsonify([]), 500 From 83dfa60ad4c2fc7bf029d4fb0742cf72102e729a Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 12:59:52 +0300 Subject: [PATCH 08/58] clean more lints --- api/agents.py | 9 ++- api/helpers/crm_data_generator.py | 5 +- api/loaders/csv_loader.py | 39 ++++++++--- api/loaders/graph_loader.py | 3 +- api/loaders/odata_loader.py | 12 ++-- api/loaders/schema_validator.py | 109 +++++++++++++++++++++--------- 6 files changed, 125 insertions(+), 52 deletions(-) diff --git a/api/agents.py b/api/agents.py index c7914953..8e268873 100644 --- a/api/agents.py +++ b/api/agents.py @@ -385,14 +385,17 @@ def get_answer(self, question: str, sql: str) -> str: return answer -TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query provde a single clarification question to the user. -* For any SQL query that contain WHERE clause, provide a clarification question to the user about the generated value. +TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query \ +provde a single clarification question to the user. +* For any SQL query that contain WHERE clause, provide a clarification question to the user about the \ +generated value. * Your question can contain more than one clarification related to WHERE clause. * Please asked only about the clarifications that you need and not extand the answer. * Please ask in a polite, humen, and concise manner. * Do not meantion any tables or columns in your ouput!. * If you dont need any clarification, please answer with "I don't need any clarification." -* The user didnt saw the SQL queryor the tables, so please understand this position and ask the clarification in that way he have the relevent information to answer. +* The user didnt saw the SQL queryor the tables, so please understand this position and ask the \ +clarification in that way he have the relevent information to answer. * When you ask the user to confirm a value, please provide the value in your answer. * Mention only question about values and dont mention the SQL query or the tables in your answer. diff --git a/api/helpers/crm_data_generator.py b/api/helpers/crm_data_generator.py index 9330c2a9..01aa1262 100644 --- a/api/helpers/crm_data_generator.py +++ b/api/helpers/crm_data_generator.py @@ -183,7 +183,10 @@ def get_table_prompt( and "columns" in existing_tables[related] and example_count < 2 ): - related_examples += f"\nRelated table example:\n```json\n{json.dumps({related: existing_tables[related]}, indent=2)}\n```\n" + related_examples += ( + f"\nRelated table example:\n```json\n" + f"{json.dumps({related: existing_tables[related]}, indent=2)}\n```\n" + ) example_count += 1 # Use contacts table as primary example if no related examples found diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index 05d22659..6ad0867d 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -128,7 +128,9 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # 'source_field': source_field, # 'target_table': target_table, # 'cardinality': cardinality, - # 'target_field': df.to_dict("records")[idx+1]['Array Field'] if not pd.isna(df.to_dict("records")[idx+1]['Array Field']) else '' + # 'target_field': df.to_dict("records")[idx+1]['Array Field'] \ + # if not pd.isna(df.to_dict("records")[idx+1] \ + # ['Array Field']) else '' # }) tables[target_table]["description"] = field_desc @@ -198,7 +200,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: except Exception as e: return False, f"Error loading CSV: {str(e)}" # else: - # # For case 2: when no primary key table exists, connect all FK tables to each other + # # For case 2: when no primary key table exists, \ + # # connect all FK tables to each other # graph.query( # """ # CREATE (src: Column {name: $col, cardinality: $cardinality}) @@ -262,16 +265,24 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # batch_size = 50 # col_descriptions = table_info['col_descriptions'] # for batch in tqdm.tqdm( -# [col_descriptions[i:i + batch_size] for i in range(0, len(col_descriptions), batch_size)], +# [col_descriptions[i:i + batch_size] \ +# for i in range(0, len(col_descriptions), batch_size)], # desc=f"Creating embeddings for {table_name}"): -# embedding_result = embedding(model='bedrock/cohere.embed-english-v3', input=batch[:95], aws_profile_name=Config.AWS_PROFILE, aws_region_name=Config.AWS_REGION) +# embedding_result = embedding( +# model='bedrock/cohere.embed-english-v3', +# input=batch[:95], +# aws_profile_name=Config.AWS_PROFILE, +# aws_region_name=Config.AWS_REGION) # embed_columns.extend([emb.values for emb in embedding_result.embeddings]) # except Exception as e: # print(f"Error creating embeddings: {str(e)}") # # Create column nodes -# for idx, (col_name, col_info) in tqdm.tqdm(enumerate(table_info['columns'].items()), desc=f"Creating columns for {table_name}", total=len(table_info['columns'])): +# for idx, (col_name, col_info) in tqdm.tqdm( +# enumerate(table_info['columns'].items()), +# desc=f"Creating columns for {table_name}", +# total=len(table_info['columns'])): # # embedding_result = embedding( # # model=Config.EMBEDDING_MODEL, # # input=[col_info['description'] if col_info['description'] else col_name] @@ -309,14 +320,20 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # ) # # Third pass: Create relationships -# for table_name, table_info in tqdm.tqdm(tables.items(), desc="Creating relationships"): +# for table_name, table_info in tqdm.tqdm(tables.items(), \ +# desc="Creating relationships"): # for rel in table_info['relationships']: # source_field = rel['source_field'] # target_table = rel['target_table'] # cardinality = rel['cardinality'] -# target_field = rel['target_field']#list(tables[tables[table_name]['relationships'][-1]['target_table']]['columns'].keys())[0] +# target_field = rel['target_field'] # \ +# # list(tables[tables[table_name]['relationships'][-1] \ +# # ['target_table']]['columns'].keys())[0] # # Create constraint name -# constraint_name = f"fk_{table_name.replace('.', '_')}_{source_field}_to_{target_table.replace('.', '_')}" +# constraint_name = ( +# f"fk_{table_name.replace('.', '_')}_{source_field}_to_" +# f"{target_table.replace('.', '_')}" +# ) # # Create relationship if both tables and columns exist # try: @@ -343,7 +360,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # except Exception as e: # print(f"Warning: Could not create relationship: {str(e)}") # continue -# for key, tables_info in tqdm.tqdm(rel_table.items(), desc="Creating relationships from names"): +# for key, tables_info in tqdm.tqdm(rel_table.items(), \ +# desc="Creating relationships from names"): # if len(tables_info['fk_tables']) > 0: # fk_tables = list(set(tables_info['fk_tables'])) # if len(tables_info['primary_key_table']) > 0: @@ -369,7 +387,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: # } # ) # else: -# # For case 2: when no primary key table exists, connect all FK tables to each other +# # For case 2: when no primary key table exists, \ +# # connect all FK tables to each other # graph.query( # """ # CREATE (src: Column {name: $col, cardinality: $cardinality}) diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index a6edc96e..82fe5342 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -86,7 +86,8 @@ def load_to_graph( ) # Batch embeddings for table columns - # TODO: Check if the embedding model and description are correct (without 2 sources of truth) + # TODO: Check if the embedding model and description are correct \ + # (without 2 sources of truth) batch_flag = True col_descriptions = table_info.get("col_descriptions") if col_descriptions is None: diff --git a/api/loaders/odata_loader.py b/api/loaders/odata_loader.py index 4562dd77..d498e4dc 100644 --- a/api/loaders/odata_loader.py +++ b/api/loaders/odata_loader.py @@ -73,7 +73,8 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: print(f"Error parsing property {prop_name} for entity {entity_name}") continue - # = {prop.get("Name"): prop.get("Type") for prop in entity_type.findall("edm:Property", namespaces)} + # = {prop.get("Name"): prop.get("Type") \ + # for prop in entity_type.findall("edm:Property", namespaces)} description = entity_type.findall("edm:Annotation", namespaces) if len(description) > 0: entities[entity_name]["description"] = ( @@ -111,13 +112,15 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: source_fields = entities.get(entity_name, {})["columns"] target_fields = entities.get(target_entity, {})["columns"] - # TODO This usage is for demonstration purposes only, it should be replaced with a more robust method + # TODO This usage is for demonstration purposes only, it should be \ + # replaced with a more robust method source_col, target_col = guess_relationship_columns(source_fields, target_fields) if source_col and target_col: # Store the relationship if rel_name not in relationships: relationships[rel_name] = [] - # src_col, tgt_col = guess_relationship_columns(source_entity, target_entity, entities[source_entity], entities[target_entity]) + # src_col, tgt_col = guess_relationship_columns(source_entity, \ + # target_entity, entities[source_entity], entities[target_entity]) relationships[rel_name].append( { "from": source_entity, @@ -133,7 +136,8 @@ def _parse_odata_schema(data) -> Tuple[dict, dict]: return entities, relationships -# TODO: this funtion is for demonstration purposes only, it should be replaced with a more robust method +# TODO: this funtion is for demonstration purposes only, it should be \ +# replaced with a more robust method def guess_relationship_columns(source_fields, target_fields): for src_key, src_meta in source_fields.items(): if src_key == "description": diff --git a/api/loaders/schema_validator.py b/api/loaders/schema_validator.py index 1f0ae169..32c54cb1 100644 --- a/api/loaders/schema_validator.py +++ b/api/loaders/schema_validator.py @@ -1,8 +1,19 @@ +"""Schema validation module for table schemas.""" + REQUIRED_COLUMN_KEYS = {"description", "type", "null", "key", "default"} VALID_NULL_VALUES = {"YES", "NO"} def validate_table_schema(schema): + """ + Validate a table schema structure. + + Args: + schema (dict): The schema dictionary to validate + + Returns: + list: List of validation errors found + """ errors = [] # Validate top-level database key @@ -15,44 +26,76 @@ def validate_table_schema(schema): return errors for table_name, table_data in schema["tables"].items(): - if not table_data.get("description"): - errors.append(f"Table '{table_name}' is missing a description") + errors.extend(_validate_table(table_name, table_data)) - if "columns" not in table_data or not isinstance(table_data["columns"], dict): - errors.append(f"Table '{table_name}' has no valid 'columns' definition") - continue + return errors - for column_name, column_data in table_data["columns"].items(): - # Check for missing required keys - missing_keys = REQUIRED_COLUMN_KEYS - column_data.keys() - if missing_keys: - errors.append( - f"Column '{column_name}' in table '{table_name}' is missing keys: {missing_keys}" - ) - continue - # Validate non-empty description - if not column_data.get("description"): - errors.append( - f"Column '{column_name}' in table '{table_name}' has an empty description" - ) +def _validate_table(table_name, table_data): + """Validate a single table's structure.""" + errors = [] + + if not table_data.get("description"): + errors.append(f"Table '{table_name}' is missing a description") + + if "columns" not in table_data or not isinstance(table_data["columns"], dict): + errors.append(f"Table '{table_name}' has no valid 'columns' definition") + return errors + + for column_name, column_data in table_data["columns"].items(): + errors.extend(_validate_column(table_name, column_name, column_data)) + + # Optional: validate foreign keys + if "foreign_keys" in table_data: + errors.extend(_validate_foreign_keys(table_name, table_data["foreign_keys"])) + + return errors + - # Validate 'null' field - if column_data["null"] not in VALID_NULL_VALUES: +def _validate_column(table_name, column_name, column_data): + """Validate a single column's structure.""" + errors = [] + + # Check for missing required keys + missing_keys = REQUIRED_COLUMN_KEYS - column_data.keys() + if missing_keys: + errors.append( + f"Column '{column_name}' in table '{table_name}' " + f"is missing keys: {missing_keys}" + ) + return errors + + # Validate non-empty description + if not column_data.get("description"): + errors.append( + f"Column '{column_name}' in table '{table_name}' has an empty description" + ) + + # Validate 'null' field + if column_data["null"] not in VALID_NULL_VALUES: + errors.append( + f"Column '{column_name}' in table '{table_name}' " + f"has invalid 'null' value: {column_data['null']}" + ) + + return errors + + +def _validate_foreign_keys(table_name, foreign_keys): + """Validate foreign keys structure.""" + errors = [] + + if not isinstance(foreign_keys, dict): + errors.append( + f"Foreign keys for table '{table_name}' must be a dictionary" + ) + return errors + + for fk_name, fk_data in foreign_keys.items(): + for key in ("column", "referenced_table", "referenced_column"): + if key not in fk_data or not fk_data[key]: errors.append( - f"Column '{column_name}' in table '{table_name}' has invalid 'null' value: {column_data['null']}" + f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'" ) - # Optional: validate foreign keys - if "foreign_keys" in table_data: - if not isinstance(table_data["foreign_keys"], dict): - errors.append(f"Foreign keys for table '{table_name}' must be a dictionary") - else: - for fk_name, fk_data in table_data["foreign_keys"].items(): - for key in ("column", "referenced_table", "referenced_column"): - if key not in fk_data or not fk_data[key]: - errors.append( - f"Foreign key '{fk_name}' in table '{table_name}' is missing '{key}'" - ) - return errors From c2c96dff30477452794c4aca3a95394b58a90a7c Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 14:32:57 +0300 Subject: [PATCH 09/58] fix build --- api/loaders/csv_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/api/loaders/csv_loader.py b/api/loaders/csv_loader.py index 6ad0867d..8beda5a5 100644 --- a/api/loaders/csv_loader.py +++ b/api/loaders/csv_loader.py @@ -4,7 +4,6 @@ from collections import defaultdict from typing import Tuple -import pandas as pd import tqdm from api.loaders.base_loader import BaseLoader @@ -27,6 +26,8 @@ def load(graph_id: str, data) -> Tuple[bool, str]: Tuple of (success, message) """ raise NotImplementedError("CSVLoader is not implemented yet") + import pandas as pd + try: # Parse CSV data using pandas for better handling of large files df = pd.read_csv(io.StringIO(data), encoding="utf-8") From cfbd89319e7be4e6f671cede76bbd5554a75253e Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 14:50:21 +0300 Subject: [PATCH 10/58] remove unused import --- api/loaders/odata_loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/loaders/odata_loader.py b/api/loaders/odata_loader.py index d498e4dc..77125d4d 100644 --- a/api/loaders/odata_loader.py +++ b/api/loaders/odata_loader.py @@ -4,7 +4,6 @@ import tqdm -from api.extensions import db from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph From 4cc4799e091c4d1678dc786f3abd73cf8e9245f8 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 22:51:28 +0300 Subject: [PATCH 11/58] add login --- api/index.py | 48 ++++++++++++++---- api/templates/{chat.html => chat.j2} | 24 +++++++++ poetry.lock | 75 +++++++++++++++++++++++++++- pyproject.toml | 1 + 4 files changed, 136 insertions(+), 12 deletions(-) rename api/templates/{chat.html => chat.j2} (74%) diff --git a/api/index.py b/api/index.py index d8bbbffe..81e449d1 100644 --- a/api/index.py +++ b/api/index.py @@ -10,6 +10,8 @@ from dotenv import load_dotenv from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context +from flask import session, redirect, url_for +from flask_dance.contrib.google import make_google_blueprint, google from api.agents import AnalysisAgent, RelevancyAgent from api.constants import EXAMPLES @@ -52,23 +54,29 @@ def decorated_function(*args, **kwargs): app = Flask(__name__) - -# @app.before_request -# def before_request_func(): -# oidc_token = request.headers.get('x-vercel-oidc-token') -# if oidc_token: -# set_oidc_token(oidc_token) -# credentials = assume_role() -# else: -# # Optional: require it for protected routes -# pass +app.secret_key = os.getenv("FLASK_SECRET_KEY", "supersekrit") + +# Google OAuth setup +GOOGLE_CLIENT_ID = os.getenv("GOOGLE_CLIENT_ID") +GOOGLE_CLIENT_SECRET = os.getenv("GOOGLE_CLIENT_SECRET") +google_bp = make_google_blueprint( + client_id=GOOGLE_CLIENT_ID, + client_secret=GOOGLE_CLIENT_SECRET, + scope=[ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "openid" + ] +) +app.register_blueprint(google_bp, url_prefix="/login") @app.route("/") @token_required # Apply token authentication decorator def home(): """Home route""" - return render_template("chat.html") + is_authenticated = "google_oauth_token" in session + return render_template("chat.j2", is_authenticated=is_authenticated) @app.route("/graphs") @@ -281,6 +289,24 @@ def suggestions(): logging.error("Error fetching suggestions: %s", e) return jsonify([]), 500 +@app.route("/login") +def login_google(): + if not google.authorized: + return redirect(url_for("google.login")) + resp = google.get("/oauth2/v2/userinfo") + if resp.ok: + user_info = resp.json() + session["google_user"] = user_info + # You can set your own token/session logic here + return redirect(url_for("home")) + return "Could not fetch your information from Google.", 400 + + +@app.route("/logout") +def logout(): + session.clear() + return redirect(url_for("home")) + if __name__ == "__main__": app.register_blueprint(main) diff --git a/api/templates/chat.html b/api/templates/chat.j2 similarity index 74% rename from api/templates/chat.html rename to api/templates/chat.j2 index 6f778f88..d52fa7cd 100644 --- a/api/templates/chat.html +++ b/api/templates/chat.j2 @@ -10,6 +10,9 @@ + {% if is_authenticated %} + Logout + {% endif %}
+ + \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index 486ad416..2ca47e8e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -478,6 +478,32 @@ werkzeug = ">=3.1.0" async = ["asgiref (>=3.2)"] dotenv = ["python-dotenv"] +[[package]] +name = "flask-dance" +version = "7.1.0" +description = "Doing the OAuth dance with style using Flask, requests, and oauthlib" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "flask_dance-7.1.0-py3-none-any.whl", hash = "sha256:81599328a2b3604fd4332b3d41a901cf36980c2067e5e38c44ce3b85c4e1ae9c"}, + {file = "flask_dance-7.1.0.tar.gz", hash = "sha256:6d0510e284f3d6ff05af918849791b17ef93a008628ec33f3a80578a44b51674"}, +] + +[package.dependencies] +Flask = ">=2.0.3" +oauthlib = ">=3.2" +requests = ">=2.0" +requests-oauthlib = ">=1.0.0" +urlobject = "*" +Werkzeug = "*" + +[package.extras] +docs = ["Flask-Sphinx-Themes", "betamax", "pillow (<=9.5)", "pytest", "sphinx (>=1.3)", "sphinxcontrib-seqdiag", "sphinxcontrib-spelling", "sqlalchemy (>=1.3.11)"] +signals = ["blinker"] +sqla = ["sqlalchemy (>=1.3.11)"] +test = ["betamax", "coverage", "flask-caching", "flask-login", "flask-sqlalchemy", "freezegun", "oauthlib[signedtoken]", "pytest", "pytest-mock", "responses", "sqlalchemy (>=1.3.11)"] + [[package]] name = "frozenlist" version = "1.7.0" @@ -1219,6 +1245,23 @@ files = [ {file = "multidict-6.6.2.tar.gz", hash = "sha256:c1e8b8b0523c0361a78ce9b99d9850c51cf25e1fa3c5686030ce75df6fdf2918"}, ] +[[package]] +name = "oauthlib" +version = "3.3.1" +description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1"}, + {file = "oauthlib-3.3.1.tar.gz", hash = "sha256:0f0f8aa759826a193cf66c12ea1af1637f87b9b4622d46e866952bb022e538c9"}, +] + +[package.extras] +rsa = ["cryptography (>=3.0.0)"] +signals = ["blinker (>=1.4.0)"] +signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] + [[package]] name = "openai" version = "1.93.0" @@ -1947,6 +1990,25 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "requests-oauthlib" +version = "2.0.0" +description = "OAuthlib authentication support for Requests." +optional = false +python-versions = ">=3.4" +groups = ["main"] +files = [ + {file = "requests-oauthlib-2.0.0.tar.gz", hash = "sha256:b3dffaebd884d8cd778494369603a9e7b58d29111bf6b41bdc2dcd87203af4e9"}, + {file = "requests_oauthlib-2.0.0-py2.py3-none-any.whl", hash = "sha256:7dd8a5c40426b779b0868c404bdef9768deccf22749cde15852df527e6269b36"}, +] + +[package.dependencies] +oauthlib = ">=3.0.0" +requests = ">=2.0.0" + +[package.extras] +rsa = ["oauthlib[signedtoken] (>=3.0.0)"] + [[package]] name = "rpds-py" version = "0.25.1" @@ -2276,6 +2338,17 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "urlobject" +version = "2.4.3" +description = "A utility class for manipulating URLs." +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "URLObject-2.4.3.tar.gz", hash = "sha256:47b2e20e6ab9c8366b2f4a3566b6ff4053025dad311c4bb71279bbcfa2430caa"}, +] + [[package]] name = "werkzeug" version = "3.1.3" @@ -2436,4 +2509,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<3.13" -content-hash = "ee04db8da4b24fe0934a010563416a179dc15ee351d46a380f778fb1086a0898" +content-hash = "f883bca3ecc7074ea013650a53f67a0a78dca576a65bd5732a4be3cfc6e4a66a" diff --git a/pyproject.toml b/pyproject.toml index 5c4298ab..7c896dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ jsonschema = "^4.23.0" tqdm = "^4.67.1" boto3 = "^1.37.29" psycopg2-binary = "^2.9.9" +flask-dance = "^7.1.0" [tool.poetry.group.test.dependencies] pytest = "^8.2.0" From 5559339e3662feb94a3ef98c784e688c33e58ba5 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 23:37:55 +0300 Subject: [PATCH 12/58] add postgres button --- api/static/css/chat.css | 138 ++++++++++++++++++++++++++++++++++++++++ api/templates/chat.j2 | 36 ++++++++--- 2 files changed, 166 insertions(+), 8 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index fea8e19b..b522eedb 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -588,6 +588,108 @@ body { max-height: none; } +.pg-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.5); + z-index: 3000; + align-items: center; + justify-content: center; +} +.pg-modal-content { + background: #fff; + padding: 2em 2.5em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); + text-align: center; + min-width: 340px; +} +.pg-modal-title { + margin-bottom: 1em; + color: #222; +} +.pg-modal-input { + width: 100%; + padding: 0.6em; + font-size: 1em; + border: 1px solid #ccc; + border-radius: 4px; + margin-bottom: 1.5em; + color: #222; + background: #fafafa; +} +.pg-modal-actions { + display: flex; + justify-content: space-between; + gap: 1em; +} +.pg-modal-btn { + flex: 1; + padding: 0.5em 0; + border: none; + border-radius: 4px; + font-size: 1em; + font-weight: bold; + cursor: pointer; + transition: background 0.2s; +} +.pg-modal-connect { + background: #4285F4; + color: #fff; +} +.pg-modal-connect:hover { + background: #3367d6; +} +.pg-modal-cancel { + background: #e0e0e0; + color: #333; +} +.pg-modal-cancel:hover { + background: #cacaca; +} +.google-login-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.6); + z-index: 1000; + align-items: center; + justify-content: center; +} +.google-login-modal-content { + background: #fff; + padding: 2em 3em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); + text-align: center; +} +.google-login-btn { + display: inline-block; + margin-top: 1em; + padding: 0.7em 2em; + background: #4285F4; + color: #fff; + border-radius: 4px; + font-size: 1.1em; + text-decoration: none; + font-weight: 500; + transition: background 0.2s; +} +.google-login-btn:hover { + background: #3367d6; +} +.google-login-logo { + height: 20px; + vertical-align: middle; + margin-right: 8px; +} @keyframes shadow-fade { 0% { box-shadow: 0 0 20px 3px var(--falkor-tertiary) @@ -668,4 +770,40 @@ body { ::-webkit-scrollbar-thumb:hover { background-color: #a8bbbf; +} + +.pg-connect-btn { + margin-left: 12px; + padding: 0.4em 0.8em; + background: #f5f5f5; + border: 1px solid #ccc; + border-radius: 4px; + font-size: 1em; + color: #333; + cursor: pointer; + transition: background 0.2s, border 0.2s; +} + +.pg-connect-btn:hover { + background: #eaeaea; + border-color: #888; +} + +.logout-btn { + position: fixed; + top: 20px; + right: 30px; + z-index: 2000; + padding: 0.5em 1.2em; + background: #e74c3c; + color: #fff; + border-radius: 5px; + text-decoration: none; + font-weight: bold; + box-shadow: 0 2px 8px rgba(0,0,0,0.08); + transition: background 0.2s; +} + +.logout-btn:hover { + background: #c0392b; } \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index d52fa7cd..04e91fea 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -11,7 +11,7 @@ {% if is_authenticated %} - Logout + Logout {% endif %}
@@ -102,15 +102,25 @@
-
From 7ebd4da78e853efae084fd1f256f847bdf4c3236 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Mon, 30 Jun 2025 23:51:47 +0300 Subject: [PATCH 14/58] refacor code move to js file --- api/static/css/chat.css | 4 ++-- api/static/js/chat.js | 43 +++++++++++++++++++++++++++++++++++++++++ api/templates/chat.j2 | 24 +++-------------------- 3 files changed, 48 insertions(+), 23 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index f0ec2b5e..bd6bb0bc 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -623,11 +623,11 @@ body { } .pg-modal-content { background: #fff; - padding: 2em 2.5em; + padding: 2em 5em; border-radius: 10px; box-shadow: 0 2px 16px rgba(0,0,0,0.2); text-align: center; - min-width: 340px; + min-width: 680px; } .pg-modal-title { margin-bottom: 1em; diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 2d8f19ba..179adf49 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -551,4 +551,47 @@ fileUpload.addEventListener('change', function (e) { console.error('Error uploading file:', error); addMessage('Sorry, there was an error uploading your file: ' + error.message, false); }); +}); + +document.addEventListener('DOMContentLoaded', function() { + // Authentication modal logic + var isAuthenticated = window.isAuthenticated !== undefined ? window.isAuthenticated : false; + var googleLoginModal = document.getElementById('google-login-modal'); + var container = document.getElementById('container'); + if (googleLoginModal && container) { + if (!isAuthenticated) { + googleLoginModal.style.display = 'flex'; + container.style.filter = 'blur(2px)'; + } else { + googleLoginModal.style.display = 'none'; + container.style.filter = ''; + } + } + // Postgres modal logic + var pgModal = document.getElementById('pg-modal'); + var openPgModalBtn = document.getElementById('open-pg-modal'); + var cancelPgModalBtn = document.getElementById('pg-modal-cancel'); + if (openPgModalBtn && pgModal) { + openPgModalBtn.addEventListener('click', function() { + pgModal.style.display = 'flex'; + }); + } + if (cancelPgModalBtn && pgModal) { + cancelPgModalBtn.addEventListener('click', function() { + pgModal.style.display = 'none'; + }); + } + // Allow closing Postgres modal with Escape key + document.addEventListener('keydown', function(e) { + if (pgModal && pgModal.style.display === 'flex' && e.key === 'Escape') { + pgModal.style.display = 'none'; + } + }); + // Optional: Close Google login modal with Escape (if ever needed) + document.addEventListener('keydown', function(e) { + if (googleLoginModal && googleLoginModal.style.display === 'flex' && e.key === 'Escape') { + googleLoginModal.style.display = 'none'; + container.style.filter = ''; + } + }); }); \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 4bc05e63..4dbdde80 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -123,29 +123,11 @@
- + {# Set authentication state for JS before loading chat.js #} + \ No newline at end of file From 9dfa6aa829f11f319182246b69b69fb263bc02a9 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 00:17:55 +0300 Subject: [PATCH 15/58] connect db --- api/index.py | 29 +++++++++++++++++++++++++++++ api/loaders/postgres_loader.py | 18 +++++++----------- api/static/js/chat.js | 32 ++++++++++++++++++++++++++++++++ api/templates/chat.j2 | 2 +- 4 files changed, 69 insertions(+), 12 deletions(-) diff --git a/api/index.py b/api/index.py index 81e449d1..ffddcde7 100644 --- a/api/index.py +++ b/api/index.py @@ -19,6 +19,7 @@ from api.graph import find, get_db_description from api.loaders.csv_loader import CSVLoader from api.loaders.json_loader import JSONLoader +from api.loaders.postgres_loader import PostgresLoader from api.loaders.odata_loader import ODataLoader # Load environment variables from .env file @@ -307,6 +308,34 @@ def logout(): session.clear() return redirect(url_for("home")) +@app.route("/database", methods=["POST"]) +@token_required # Apply token authentication decorator +def connect_database(): + """ + Accepts a JSON payload with a Postgres URL and attempts to connect. + Returns success or error message. + """ + data = request.get_json() + url = data.get("url") if data else None + if not url: + return jsonify({"success": False, "error": "No URL provided"}), 400 + try: + # Check for Postgres URL + if url.startswith("postgres://") or url.startswith("postgresql://"): + try: + # Attempt to connect/load using the loader + success, result = PostgresLoader.load(url) + if success: + return jsonify({"success": True, "message": result}), 200 + else: + return jsonify({"success": False, "error": result}), 400 + except Exception as e: + return jsonify({"success": False, "error": str(e)}), 500 + else: + return jsonify({"success": False, "error": "Invalid Postgres URL"}), 400 + except Exception as e: + return jsonify({"success": False, "error": str(e)}), 500 + if __name__ == "__main__": app.register_blueprint(main) diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 0bb67160..ea8cc5fd 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -1,25 +1,21 @@ from typing import Tuple, Dict, Any, List import psycopg2 import tqdm -from api.config import Config from api.loaders.base_loader import BaseLoader -from api.extensions import db -from api.utils import generate_db_description from api.loaders.graph_loader import load_to_graph -class PostgreSQLLoader(BaseLoader): +class PostgresLoader(BaseLoader): """ Loader for PostgreSQL databases that connects and extracts schema information. """ @staticmethod - def load(graph_id: str, connection_url: str) -> Tuple[bool, str]: + def load(connection_url: str) -> Tuple[bool, str]: """ Load the graph data from a PostgreSQL database into the graph database. Args: - graph_id: The ID of the graph to load data into connection_url: PostgreSQL connection URL in format: postgresql://username:password@host:port/database @@ -37,17 +33,17 @@ def load(graph_id: str, connection_url: str) -> Tuple[bool, str]: db_name = db_name.split('?')[0] # Get all table information - entities = PostgreSQLLoader.extract_tables_info(cursor) + entities = PostgresLoader.extract_tables_info(cursor) # Get all relationship information - relationships = PostgreSQLLoader.extract_relationships(cursor) + relationships = PostgresLoader.extract_relationships(cursor) # Close database connection cursor.close() conn.close() # Load data into graph - load_to_graph(graph_id, entities, relationships, db_name=db_name) + load_to_graph(db_name, entities, relationships, db_name=db_name) return True, f"PostgreSQL schema loaded successfully. Found {len(entities)} tables." @@ -91,10 +87,10 @@ def extract_tables_info(cursor) -> Dict[str, Any]: table_name = table_name.strip() # Get column information for this table - columns_info = PostgreSQLLoader.extract_columns_info(cursor, table_name) + columns_info = PostgresLoader.extract_columns_info(cursor, table_name) # Get foreign keys for this table - foreign_keys = PostgreSQLLoader.extract_foreign_keys(cursor, table_name) + foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name) # Generate table description table_description = table_comment if table_comment else f"Table: {table_name}" diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 179adf49..21ef7df2 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -571,6 +571,8 @@ document.addEventListener('DOMContentLoaded', function() { var pgModal = document.getElementById('pg-modal'); var openPgModalBtn = document.getElementById('open-pg-modal'); var cancelPgModalBtn = document.getElementById('pg-modal-cancel'); + var connectPgModalBtn = document.getElementById('pg-modal-connect'); + var pgUrlInput = document.getElementById('pg-url-input'); if (openPgModalBtn && pgModal) { openPgModalBtn.addEventListener('click', function() { pgModal.style.display = 'flex'; @@ -594,4 +596,34 @@ document.addEventListener('DOMContentLoaded', function() { container.style.filter = ''; } }); + + // Handle Connect button for Postgres modal + if (connectPgModalBtn && pgUrlInput && pgModal) { + connectPgModalBtn.addEventListener('click', function() { + const pgUrl = pgUrlInput.value.trim(); + if (!pgUrl) { + alert('Please enter a Postgres URL.'); + return; + } + fetch('/database', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ url: pgUrl }) + }) + .then(response => response.json()) + .then(data => { + if (data.success) { + alert('Database connected successfully!'); + pgModal.style.display = 'none'; + } else { + alert('Failed to connect: ' + (data.error || 'Unknown error')); + } + }) + .catch(error => { + alert('Error connecting to database: ' + error.message); + }); + }); + } }); \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 4dbdde80..3120edeb 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -116,7 +116,7 @@

Connect to Postgres

- +
From 24d14512aad90c9df37ef8089fd3d44a463d37f3 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 00:20:45 +0300 Subject: [PATCH 16/58] close modal --- api/static/js/chat.js | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 21ef7df2..e2bb65de 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -615,8 +615,7 @@ document.addEventListener('DOMContentLoaded', function() { .then(response => response.json()) .then(data => { if (data.success) { - alert('Database connected successfully!'); - pgModal.style.display = 'none'; + pgModal.style.display = 'none'; // Close modal on success, no alert } else { alert('Failed to connect: ' + (data.error || 'Unknown error')); } From ff262f9f9e89d64f05d079a16bcab572e98e0b62 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 01:13:23 +0300 Subject: [PATCH 17/58] filter users graphs --- api/index.py | 53 ++++++++++++++-------------------- api/loaders/postgres_loader.py | 4 +-- 2 files changed, 23 insertions(+), 34 deletions(-) diff --git a/api/index.py b/api/index.py index ffddcde7..688b9dfb 100644 --- a/api/index.py +++ b/api/index.py @@ -9,7 +9,7 @@ from functools import wraps from dotenv import load_dotenv -from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context +from flask import Blueprint, Flask, Response, jsonify, render_template, request, stream_with_context, g from flask import session, redirect, url_for from flask_dance.contrib.google import make_google_blueprint, google @@ -35,19 +35,15 @@ SECRET_TOKEN_ERP = os.getenv("SECRET_TOKEN_ERP") -def verify_token(token): - """Verify the token provided in the request""" - return token in (SECRET_TOKEN, SECRET_TOKEN_ERP, 'null') - - def token_required(f): """Decorator to protect routes with token authentication""" @wraps(f) def decorated_function(*args, **kwargs): - token = request.args.get("token", "null") # Get token from header - os.environ["USER_TOKEN"] = token - if not verify_token(token): + user_info = session.get("google_user") + if user_info: + g.user_id = user_info.get("id") + else: return jsonify(message="Unauthorized"), 401 return f(*args, **kwargs) @@ -73,10 +69,14 @@ def decorated_function(*args, **kwargs): @app.route("/") -@token_required # Apply token authentication decorator def home(): """Home route""" is_authenticated = "google_oauth_token" in session + if is_authenticated: + resp = google.get("/oauth2/v2/userinfo") + if resp.ok: + user_info = resp.json() + session["google_user"] = user_info return render_template("chat.j2", is_authenticated=is_authenticated) @@ -86,24 +86,12 @@ def graphs(): """ This route is used to list all the graphs that are available in the database. """ + user_id = g.user_id user_graphs = db.list_graphs() - if os.getenv("USER_TOKEN") == SECRET_TOKEN: - if "hospital" in user_graphs: - return ["hospital"] - return [] - - if os.getenv("USER_TOKEN") == SECRET_TOKEN_ERP: - if "ERP_system" in user_graphs: - return ["ERP_system"] - return ["crm_usecase"] - - if os.getenv("USER_TOKEN") == "null": - if "crm_usecase" in user_graphs: - return ["crm_usecase"] - return [] - - user_graphs.remove("hospital") - return user_graphs + # Only include graphs that start with user_id + '_', and strip the prefix + filtered_graphs = [graph[len(f"{user_id}_"):] + for graph in user_graphs if graph.startswith(f"{user_id}_")] + return jsonify(filtered_graphs) @app.route("/graphs", methods=["POST"]) @@ -126,7 +114,7 @@ def load(): if not data or "database" not in data: return jsonify({"error": "Invalid JSON data"}), 400 - graph_id = data["database"] + graph_id = g.user_id + "_" + data["database"] success, result = JSONLoader.load(graph_id, data) # ✅ Handle XML Payload @@ -154,7 +142,7 @@ def load(): if file.filename.endswith(".json"): try: data = json.load(file) - graph_id = data.get("database", "") + graph_id = g.user_id + "_" + data.get("database", "") success, result = JSONLoader.load(graph_id, data) except json.JSONDecodeError: return jsonify({"error": "Invalid JSON file"}), 400 @@ -162,13 +150,13 @@ def load(): # ✅ Check if file is XML elif file.filename.endswith(".xml"): xml_data = file.read().decode("utf-8") # Convert bytes to string - graph_id = file.filename.replace(".xml", "") + graph_id = g.user_id + "_" + file.filename.replace(".xml", "") success, result = ODataLoader.load(graph_id, xml_data) # ✅ Check if file is csv elif file.filename.endswith(".csv"): csv_data = file.read().decode("utf-8") # Convert bytes to string - graph_id = file.filename.replace(".csv", "") + graph_id = g.user_id + "_" + file.filename.replace(".csv", "") success, result = CSVLoader.load(graph_id, csv_data) else: @@ -189,6 +177,7 @@ def query(graph_id: str): """ text2sql """ + graph_id = g.user_id + "_" + graph_id.strip() request_data = request.get_json() queries_history = request_data.get("chat") result_history = request_data.get("result") @@ -324,7 +313,7 @@ def connect_database(): if url.startswith("postgres://") or url.startswith("postgresql://"): try: # Attempt to connect/load using the loader - success, result = PostgresLoader.load(url) + success, result = PostgresLoader.load(g.user_id, url) if success: return jsonify({"success": True, "message": result}), 200 else: diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index ea8cc5fd..115e67d2 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -11,7 +11,7 @@ class PostgresLoader(BaseLoader): """ @staticmethod - def load(connection_url: str) -> Tuple[bool, str]: + def load(prefix: str, connection_url: str) -> Tuple[bool, str]: """ Load the graph data from a PostgreSQL database into the graph database. @@ -43,7 +43,7 @@ def load(connection_url: str) -> Tuple[bool, str]: conn.close() # Load data into graph - load_to_graph(db_name, entities, relationships, db_name=db_name) + load_to_graph(prefix + "_" + db_name, entities, relationships, db_name=db_name) return True, f"PostgreSQL schema loaded successfully. Found {len(entities)} tables." From cd74b0649ae77506fdef0ccb2100a699a74fb0c9 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 01:20:24 +0300 Subject: [PATCH 18/58] clean style --- api/static/css/chat.css | 4 +++- api/templates/chat.j2 | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index bd6bb0bc..6c7c691f 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -410,6 +410,7 @@ body { background-position: calc(100% - 20px) center, calc(100% - 15px) center; background-size: 5px 5px, 5px 5px; background-repeat: no-repeat; + cursor:pointer; } #graph-select:focus { @@ -433,9 +434,10 @@ body { background-position: calc(100% - 20px) center, calc(100% - 15px) center; background-size: 5px 5px, 5px 5px; background-repeat: no-repeat; + cursor:pointer; } -.custom-file-upload { +#open-pg-modal { height: 100%; padding: 8px 12px; border-radius: 6px; diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 3120edeb..d99340a0 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -60,10 +60,10 @@
-
From 885a6650e27fa7372f112ff5ac31322c3dd2bc24 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 10:16:08 +0300 Subject: [PATCH 19/58] update login design --- api/static/css/chat.css | 9 +++++++-- api/static/public/icons/google.svg | 1 + api/templates/chat.j2 | 3 ++- 3 files changed, 10 insertions(+), 3 deletions(-) create mode 100644 api/static/public/icons/google.svg diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 6c7c691f..9410ef27 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -694,7 +694,10 @@ body { text-align: center; } .google-login-btn { - display: inline-block; + display: flex; + align-items: center; + justify-content: flex-start; + gap: 10px; margin-top: 1em; padding: 0.7em 2em; background: #4285F4; @@ -710,8 +713,10 @@ body { } .google-login-logo { height: 20px; + margin-right: 10px; + margin-left: 0; vertical-align: middle; - margin-right: 8px; + display: inline-block; } @keyframes shadow-fade { 0% { diff --git a/api/static/public/icons/google.svg b/api/static/public/icons/google.svg new file mode 100644 index 00000000..fb354fdb --- /dev/null +++ b/api/static/public/icons/google.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index d99340a0..8d03aa42 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -63,6 +63,7 @@ +
@@ -109,7 +110,7 @@

Welcome to Text-to-SQL

Please login to continue

From 60a5affe6291c5ed52cdab1e2dc4bb6a0dfe2b9d Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 19:40:19 +0300 Subject: [PATCH 20/58] update req --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index f5cdb6c5..1a2eb4b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ distro==1.9.0 ; python_version == "3.12" falkordb==1.1.2 ; python_version == "3.12" filelock==3.18.0 ; python_version == "3.12" flask==3.1.1 ; python_version == "3.12" +flask-dance==7.1.0 ; python_version == "3.12" frozenlist==1.7.0 ; python_version == "3.12" fsspec==2025.5.1 ; python_version == "3.12" h11==0.16.0 ; python_version == "3.12" From 3267418067780c257a467446baaf2e3090b9206f Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 19:49:52 +0300 Subject: [PATCH 21/58] fix req --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 1a2eb4b6..a93da3ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -60,3 +60,4 @@ urllib3==2.5.0 ; python_version == "3.12" werkzeug==3.1.3 ; python_version == "3.12" yarl==1.20.1 ; python_version == "3.12" zipp==3.23.0 ; python_version == "3.12" +psycopg2-binary==2.9.9 ; python_version == "3.12" From d81b12bd4325f187c865201c8b88f636e7adc40a Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Tue, 1 Jul 2025 22:16:04 +0300 Subject: [PATCH 22/58] Add query result --- api/graph.py | 8 ++--- api/index.py | 18 +++++++++++- api/loaders/graph_loader.py | 6 ++-- api/loaders/postgres_loader.py | 53 +++++++++++++++++++++++++++++----- api/static/js/chat.js | 7 +++++ 5 files changed, 77 insertions(+), 15 deletions(-) diff --git a/api/graph.py b/api/graph.py index 1fa783aa..9eed9bca 100644 --- a/api/graph.py +++ b/api/graph.py @@ -35,20 +35,20 @@ class Descriptions(BaseModel): columns_descriptions: list[ColumnDescription] -def get_db_description(graph_id: str) -> str: +def get_db_description(graph_id: str) -> (str, str): """Get the database description from the graph.""" graph = db.select_graph(graph_id) query_result = graph.query( """ MATCH (d:Database) - RETURN d.description + RETURN d.description, d.url """ ) if not query_result.result_set: - return "No description available for this database." + return ("No description available for this database.", "No URL available for this database.") - return query_result.result_set[0][0] # Return the first result's description + return (query_result.result_set[0][0], query_result.result_set[0][1]) # Return the first result's description def find( diff --git a/api/index.py b/api/index.py index 688b9dfb..52fe7c82 100644 --- a/api/index.py +++ b/api/index.py @@ -194,7 +194,7 @@ def generate(): step = {"type": "reasoning_step", "message": "Step 1: Analyzing the user query"} yield json.dumps(step) + MESSAGE_DELIMITER - db_description = get_db_description(graph_id) # Ensure the database description is loaded + db_description, db_url = get_db_description(graph_id) # Ensure the database description is loaded logging.info("Calling to relvancy agent with query: %s", queries_history[-1]) answer_rel = agent_rel.get_answer(queries_history[-1], db_description) @@ -247,6 +247,22 @@ def generate(): } ) + MESSAGE_DELIMITER + # If the SQL query is valid, execute it using the postgress database db_url + if answer_an["is_sql_translatable"]: + try: + result = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) + yield json.dumps( + { + "type": "query_result", + "data": result, + } + ) + MESSAGE_DELIMITER + except Exception as e: + logging.error("Error executing SQL query: %s", e) + yield json.dumps( + {"type": "error", "message": str(e)} + ) + MESSAGE_DELIMITER + return Response(stream_with_context(generate()), content_type="application/json") diff --git a/api/loaders/graph_loader.py b/api/loaders/graph_loader.py index 82fe5342..48f1f676 100644 --- a/api/loaders/graph_loader.py +++ b/api/loaders/graph_loader.py @@ -15,6 +15,7 @@ def load_to_graph( relationships: dict, batch_size: int = 100, db_name: str = "TBD", + db_url: str = "", ) -> None: """ Load the graph data into the database. @@ -56,10 +57,11 @@ def load_to_graph( """ CREATE (d:Database { name: $db_name, - description: $description + description: $description, + url: $url }) """, - {"db_name": db_name, "description": db_des}, + {"db_name": db_name, "description": db_des, "url": db_url}, ) for table_name, table_info in tqdm.tqdm(entities.items(), desc="Creating Graph Table Nodes"): diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 115e67d2..97183839 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -26,27 +26,27 @@ def load(prefix: str, connection_url: str) -> Tuple[bool, str]: # Connect to PostgreSQL database conn = psycopg2.connect(connection_url) cursor = conn.cursor() - + # Extract database name from connection URL db_name = connection_url.split('/')[-1] if '?' in db_name: db_name = db_name.split('?')[0] - + # Get all table information entities = PostgresLoader.extract_tables_info(cursor) - + # Get all relationship information relationships = PostgresLoader.extract_relationships(cursor) - + # Close database connection cursor.close() conn.close() - + # Load data into graph - load_to_graph(prefix + "_" + db_name, entities, relationships, db_name=db_name) - + load_to_graph(prefix + "_" + db_name, entities, relationships, db_name=db_name, db_url=connection_url) + return True, f"PostgreSQL schema loaded successfully. Found {len(entities)} tables." - + except psycopg2.Error as e: return False, f"PostgreSQL connection error: {str(e)}" except Exception as e: @@ -276,3 +276,40 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: }) return relationships + + @staticmethod + def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: + """ + Execute a SQL query on the PostgreSQL database and return the results. + + Args: + sql_query: The SQL query to execute + db_url: PostgreSQL connection URL in format: + postgresql://username:password@host:port/database + + Returns: + List of dictionaries containing the query results + """ + try: + # Connect to PostgreSQL database + conn = psycopg2.connect(db_url) + cursor = conn.cursor() + + # Execute the SQL query + cursor.execute(sql_query) + columns = [desc[0] for desc in cursor.description] + results = cursor.fetchall() + + # Convert results to list of dictionaries + result_list = [dict(zip(columns, row)) for row in results] + + # Close database connection + cursor.close() + conn.close() + + return result_list + + except psycopg2.Error as e: + raise Exception(f"PostgreSQL query execution error: {str(e)}") + except Exception as e: + raise Exception(f"Error executing SQL query: {str(e)}") diff --git a/api/static/js/chat.js b/api/static/js/chat.js index e2bb65de..f3d9b66e 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -310,6 +310,13 @@ async function sendMessage() { ambValue.textContent = "N/A"; // graph.Labels.findIndex(l => l.name === cat.name)(step.message, false, true); addMessage(step.message, false, true); + } else if (step.type === 'query_result') { + // Handle query result + if (step.data) { + addMessage(`Query Result: ${JSON.stringify(step.data)}`, false, false, true); + } else { + addMessage("No results found for the query.", false); + } } else { // Default handling addMessage(step.message || JSON.stringify(step), false); From 4876a70d5ba841fa25f8f3f96346a8dd7dddeab5 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 08:04:27 +0300 Subject: [PATCH 23/58] add user readable result --- api/agents.py | 99 +++++++++++++++++++++++++++++++++++++++++++ api/index.py | 29 +++++++++++-- api/static/js/chat.js | 3 ++ 3 files changed, 128 insertions(+), 3 deletions(-) diff --git a/api/agents.py b/api/agents.py index 8e268873..679264ba 100644 --- a/api/agents.py +++ b/api/agents.py @@ -443,3 +443,102 @@ def _parse_response(response: str) -> Dict[str, Any]: "explanation": f"Failed to parse response: {str(e)}", "error": str(response), } + + +class ResponseFormatterAgent: + """Agent for generating user-readable responses from SQL query results.""" + + def __init__(self): + """Initialize the response formatter agent.""" + pass + + def format_response(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str = "") -> str: + """ + Generate a user-readable response based on the SQL query results. + + Args: + user_query: The original user question + sql_query: The SQL query that was executed + query_results: The results from the SQL query execution + db_description: Description of the database context + + Returns: + A formatted, user-readable response string + """ + prompt = self._build_response_prompt(user_query, sql_query, query_results, db_description) + + messages = [{"role": "user", "content": prompt}] + + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=messages, + temperature=0.3, # Slightly higher temperature for more natural responses + top_p=1, + ) + + response = completion_result.choices[0].message.content + return response.strip() + + def _build_response_prompt(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str) -> str: + """Build the prompt for generating user-readable responses.""" + + # Format the query results for better readability + formatted_results = self._format_query_results(query_results) + + prompt = f""" +You are an AI assistant that helps users understand database query results. Your task is to analyze the SQL query results and provide a clear, concise, and user-friendly explanation. + +**Context:** +Database Description: {db_description if db_description else "Not provided"} + +**User's Original Question:** +{user_query} + +**SQL Query Executed:** +{sql_query} + +**Query Results:** +{formatted_results} + +**Instructions:** +1. Provide a clear, natural language answer to the user's question based on the query results +2. Focus on the key insights and findings from the data +3. Use bullet points or numbered lists when presenting multiple items +4. Include relevant numbers, percentages, or trends if applicable +5. Be concise but comprehensive - avoid unnecessary technical jargon +6. If the results are empty, explain that no data was found matching the criteria +7. If there are many results, provide a summary with highlights +8. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding + +**Response Format:** +Provide a direct answer to the user's question in a conversational tone, as if you were explaining the findings to a colleague. +""" + + return prompt + + def _format_query_results(self, query_results: List[Dict]) -> str: + """Format query results for inclusion in the prompt.""" + if not query_results: + return "No results found." + + if len(query_results) == 0: + return "No results found." + + # Limit the number of results shown in the prompt to avoid token limits + max_results_to_show = 50 + results_to_show = query_results[:max_results_to_show] + + formatted = [] + for i, result in enumerate(results_to_show, 1): + if isinstance(result, dict): + result_str = ", ".join([f"{k}: {v}" for k, v in result.items()]) + formatted.append(f"{i}. {result_str}") + else: + formatted.append(f"{i}. {result}") + + result_text = "\n".join(formatted) + + if len(query_results) > max_results_to_show: + result_text += f"\n... and {len(query_results) - max_results_to_show} more results" + + return result_text diff --git a/api/index.py b/api/index.py index 52fe7c82..5d554114 100644 --- a/api/index.py +++ b/api/index.py @@ -13,7 +13,7 @@ from flask import session, redirect, url_for from flask_dance.contrib.google import make_google_blueprint, google -from api.agents import AnalysisAgent, RelevancyAgent +from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent from api.constants import EXAMPLES from api.extensions import db from api.graph import find, get_db_description @@ -250,13 +250,36 @@ def generate(): # If the SQL query is valid, execute it using the postgress database db_url if answer_an["is_sql_translatable"]: try: - result = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) + step = {"type": "reasoning_step", "message": "Step 3: Executing SQL query"} + yield json.dumps(step) + MESSAGE_DELIMITER + + query_results = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) yield json.dumps( { "type": "query_result", - "data": result, + "data": query_results, } ) + MESSAGE_DELIMITER + + # Generate user-readable response using AI + step = {"type": "reasoning_step", "message": "Step 4: Generating user-friendly response"} + yield json.dumps(step) + MESSAGE_DELIMITER + + response_agent = ResponseFormatterAgent() + user_readable_response = response_agent.format_response( + user_query=queries_history[-1], + sql_query=answer_an["sql_query"], + query_results=query_results, + db_description=db_description + ) + + yield json.dumps( + { + "type": "ai_response", + "message": user_readable_response, + } + ) + MESSAGE_DELIMITER + except Exception as e: logging.error("Error executing SQL query: %s", e) yield json.dumps( diff --git a/api/static/js/chat.js b/api/static/js/chat.js index f3d9b66e..d2bc7c4e 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -317,6 +317,9 @@ async function sendMessage() { } else { addMessage("No results found for the query.", false); } + } else if (step.type === 'ai_response') { + // Handle AI-generated user-friendly response + addMessage(step.message, false, false, true); } else { // Default handling addMessage(step.message || JSON.stringify(step), false); From 388ec6e4e52708266d9dc1b5a340f0ec44c3be90 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 08:21:11 +0300 Subject: [PATCH 24/58] handle none SELECT queries --- api/agents.py | 34 +++++++++++++++++++++------ api/loaders/postgres_loader.py | 43 ++++++++++++++++++++++++++++++---- 2 files changed, 65 insertions(+), 12 deletions(-) diff --git a/api/agents.py b/api/agents.py index 679264ba..caae916a 100644 --- a/api/agents.py +++ b/api/agents.py @@ -485,6 +485,9 @@ def _build_response_prompt(self, user_query: str, sql_query: str, query_results: # Format the query results for better readability formatted_results = self._format_query_results(query_results) + # Determine the type of SQL operation + sql_type = sql_query.strip().split()[0].upper() if sql_query else "UNKNOWN" + prompt = f""" You are an AI assistant that helps users understand database query results. Your task is to analyze the SQL query results and provide a clear, concise, and user-friendly explanation. @@ -497,18 +500,22 @@ def _build_response_prompt(self, user_query: str, sql_query: str, query_results: **SQL Query Executed:** {sql_query} +**Query Type:** {sql_type} + **Query Results:** {formatted_results} **Instructions:** 1. Provide a clear, natural language answer to the user's question based on the query results -2. Focus on the key insights and findings from the data -3. Use bullet points or numbered lists when presenting multiple items -4. Include relevant numbers, percentages, or trends if applicable -5. Be concise but comprehensive - avoid unnecessary technical jargon -6. If the results are empty, explain that no data was found matching the criteria -7. If there are many results, provide a summary with highlights -8. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding +2. For SELECT queries: Focus on the key insights and findings from the data +3. For INSERT/UPDATE/DELETE queries: Confirm the operation was successful and mention how many records were affected +4. For other operations (CREATE, DROP, etc.): Confirm the operation was completed successfully +5. Use bullet points or numbered lists when presenting multiple items +6. Include relevant numbers, percentages, or trends if applicable +7. Be concise but comprehensive - avoid unnecessary technical jargon +8. If the results are empty, explain that no data was found matching the criteria +9. If there are many results, provide a summary with highlights +10. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding **Response Format:** Provide a direct answer to the user's question in a conversational tone, as if you were explaining the findings to a colleague. @@ -524,6 +531,19 @@ def _format_query_results(self, query_results: List[Dict]) -> str: if len(query_results) == 0: return "No results found." + # Check if this is an operation result (INSERT/UPDATE/DELETE) + if len(query_results) == 1 and "operation" in query_results[0]: + result = query_results[0] + operation = result.get("operation", "UNKNOWN") + affected_rows = result.get("affected_rows") + status = result.get("status", "unknown") + + if affected_rows is not None: + return f"Operation: {operation}, Status: {status}, Affected rows: {affected_rows}" + else: + return f"Operation: {operation}, Status: {status}" + + # Handle regular SELECT query results # Limit the number of results shown in the prompt to avoid token limits max_results_to_show = 50 results_to_show = query_results[:max_results_to_show] diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 97183839..915e0ca0 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -297,11 +297,34 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: # Execute the SQL query cursor.execute(sql_query) - columns = [desc[0] for desc in cursor.description] - results = cursor.fetchall() - - # Convert results to list of dictionaries - result_list = [dict(zip(columns, row)) for row in results] + + # Check if the query returns results (SELECT queries) + if cursor.description is not None: + # This is a SELECT query or similar that returns rows + columns = [desc[0] for desc in cursor.description] + results = cursor.fetchall() + result_list = [dict(zip(columns, row)) for row in results] + else: + # This is an INSERT, UPDATE, DELETE, or other non-SELECT query + # Return information about the operation + affected_rows = cursor.rowcount + sql_type = sql_query.strip().split()[0].upper() + + if sql_type in ['INSERT', 'UPDATE', 'DELETE']: + result_list = [{ + "operation": sql_type, + "affected_rows": affected_rows, + "status": "success" + }] + else: + # For other types of queries (CREATE, DROP, etc.) + result_list = [{ + "operation": sql_type, + "status": "success" + }] + + # Commit the transaction for write operations + conn.commit() # Close database connection cursor.close() @@ -310,6 +333,16 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: return result_list except psycopg2.Error as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() raise Exception(f"PostgreSQL query execution error: {str(e)}") except Exception as e: + # Rollback in case of error + if 'conn' in locals(): + conn.rollback() + cursor.close() + conn.close() raise Exception(f"Error executing SQL query: {str(e)}") From 337012bb7482a93cb1cf9f1d49680104230fa491 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 08:31:25 +0300 Subject: [PATCH 25/58] merge step 1 and 2 to single message --- api/index.py | 8 +- docs/postgres_loader.md | 240 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+), 5 deletions(-) create mode 100644 docs/postgres_loader.md diff --git a/api/index.py b/api/index.py index 5d554114..5f7c8d35 100644 --- a/api/index.py +++ b/api/index.py @@ -192,7 +192,7 @@ def generate(): agent_rel = RelevancyAgent(queries_history, result_history) agent_an = AnalysisAgent(queries_history, result_history) - step = {"type": "reasoning_step", "message": "Step 1: Analyzing the user query"} + step = {"type": "reasoning_step", "message": "Step 1: Analyzing user query and generating SQL..."} yield json.dumps(step) + MESSAGE_DELIMITER db_description, db_url = get_db_description(graph_id) # Ensure the database description is loaded @@ -227,8 +227,6 @@ def generate(): ) + MESSAGE_DELIMITER return - step = {"type": "reasoning_step", "message": "Step 2: Generating SQL query"} - yield json.dumps(step) + MESSAGE_DELIMITER logging.info("Calling to analysis agent with query: %s", queries_history[-1]) answer_an = agent_an.get_analysis( queries_history[-1], result, db_description, instructions @@ -250,7 +248,7 @@ def generate(): # If the SQL query is valid, execute it using the postgress database db_url if answer_an["is_sql_translatable"]: try: - step = {"type": "reasoning_step", "message": "Step 3: Executing SQL query"} + step = {"type": "reasoning_step", "message": "Step 2: Executing SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER query_results = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) @@ -262,7 +260,7 @@ def generate(): ) + MESSAGE_DELIMITER # Generate user-readable response using AI - step = {"type": "reasoning_step", "message": "Step 4: Generating user-friendly response"} + step = {"type": "reasoning_step", "message": "Step 3: Generating user-friendly response"} yield json.dumps(step) + MESSAGE_DELIMITER response_agent = ResponseFormatterAgent() diff --git a/docs/postgres_loader.md b/docs/postgres_loader.md new file mode 100644 index 00000000..d4024fd1 --- /dev/null +++ b/docs/postgres_loader.md @@ -0,0 +1,240 @@ +# PostgreSQL Schema Loader + +This loader connects to a PostgreSQL database and extracts the complete schema information, including tables, columns, relationships, and constraints. The extracted schema is then loaded into a graph database for further analysis and query generation. + +## Features + +- **Complete Schema Extraction**: Retrieves all tables, columns, data types, constraints, and relationships +- **Foreign Key Relationships**: Automatically discovers and maps foreign key relationships between tables +- **Column Metadata**: Extracts column comments, default values, nullability, and key types +- **Batch Processing**: Efficiently processes large schemas with progress tracking +- **Error Handling**: Robust error handling for connection issues and malformed schemas + +## Installation + +{% capture shell_0 %} +poetry add psycopg2-binary +{% endcapture %} + +{% capture shell_1 %} +pip install psycopg2-binary +{% endcapture %} + +{% include code_tabs.html id="install_tabs" shell=shell_0 shell2=shell_1 %} + +## Usage + +### Basic Usage + +{% capture python_0 %} +from api.loaders.postgres_loader import PostgreSQLLoader + +# Connection URL format: postgresql://username:password@host:port/database +connection_url = "postgresql://postgres:password@localhost:5432/mydatabase" +graph_id = "my_schema_graph" + +success, message = PostgreSQLLoader.load(graph_id, connection_url) + +if success: + print(f"Schema loaded successfully: {message}") +else: + print(f"Failed to load schema: {message}") +{% endcapture %} + +{% capture javascript_0 %} +import { PostgreSQLLoader } from 'your-pkg'; + +const connectionUrl = "postgresql://postgres:password@localhost:5432/mydatabase"; +const graphId = "my_schema_graph"; + +const [success, message] = await PostgreSQLLoader.load(graphId, connectionUrl); +if (success) { + console.log(`Schema loaded successfully: ${message}`); +} else { + console.log(`Failed to load schema: ${message}`); +} +{% endcapture %} + +{% capture java_0 %} +String connectionUrl = "postgresql://postgres:password@localhost:5432/mydatabase"; +String graphId = "my_schema_graph"; +Pair result = PostgreSQLLoader.load(graphId, connectionUrl); +if (result.getLeft()) { + System.out.println("Schema loaded successfully: " + result.getRight()); +} else { + System.out.println("Failed to load schema: " + result.getRight()); +} +{% endcapture %} + +{% capture rust_0 %} +let connection_url = "postgresql://postgres:password@localhost:5432/mydatabase"; +let graph_id = "my_schema_graph"; +let (success, message) = postgresql_loader::load(graph_id, connection_url)?; +if success { + println!("Schema loaded successfully: {}", message); +} else { + println!("Failed to load schema: {}", message); +} +{% endcapture %} + +{% include code_tabs.html id="basic_usage_tabs" python=python_0 javascript=javascript_0 java=java_0 rust=rust_0 %} + +### Connection URL Format + +``` +postgresql://[username[:password]@][host[:port]][/database][?options] +``` + +**Examples:** +- `postgresql://postgres:password@localhost:5432/mydatabase` +- `postgresql://user:pass@example.com:5432/production_db` +- `postgresql://postgres@127.0.0.1/testdb` + +### Integration with Graph Database + +{% capture python_1 %} +from api.loaders.postgres_loader import PostgreSQLLoader +from api.extensions import db + +# Load PostgreSQL schema into graph +graph_id = "customer_db_schema" +connection_url = "postgresql://postgres:password@localhost:5432/customers" + +success, message = PostgreSQLLoader.load(graph_id, connection_url) + +if success: + # The schema is now available in the graph database + graph = db.select_graph(graph_id) + + # Query for all tables + result = graph.query("MATCH (t:Table) RETURN t.name") + print("Tables:", [record[0] for record in result.result_set]) +{% endcapture %} + +{% capture javascript_1 %} +import { PostgreSQLLoader, db } from 'your-pkg'; + +const graphId = "customer_db_schema"; +const connectionUrl = "postgresql://postgres:password@localhost:5432/customers"; + +const [success, message] = await PostgreSQLLoader.load(graphId, connectionUrl); +if (success) { + const graph = db.selectGraph(graphId); + const result = await graph.query("MATCH (t:Table) RETURN t.name"); + console.log("Tables:", result.map(r => r[0])); +} +{% endcapture %} + +{% capture java_1 %} +String graphId = "customer_db_schema"; +String connectionUrl = "postgresql://postgres:password@localhost:5432/customers"; +Pair result = PostgreSQLLoader.load(graphId, connectionUrl); +if (result.getLeft()) { + Graph graph = db.selectGraph(graphId); + ResultSet rs = graph.query("MATCH (t:Table) RETURN t.name"); + // Print table names + for (Record record : rs) { + System.out.println(record.get(0)); + } +} +{% endcapture %} + +{% capture rust_1 %} +let graph_id = "customer_db_schema"; +let connection_url = "postgresql://postgres:password@localhost:5432/customers"; +let (success, message) = postgresql_loader::load(graph_id, connection_url)?; +if success { + let graph = db.select_graph(graph_id); + let result = graph.query("MATCH (t:Table) RETURN t.name")?; + println!("Tables: {:?}", result.iter().map(|r| &r[0]).collect::>()); +} +{% endcapture %} + +{% include code_tabs.html id="integration_tabs" python=python_1 javascript=javascript_1 java=java_1 rust=rust_1 %} + +## Schema Structure + +The loader extracts the following information: + +### Tables +- Table name +- Table description/comment +- Column information +- Foreign key relationships + +### Columns +- Column name +- Data type +- Nullability +- Default values +- Key type (PRIMARY KEY, FOREIGN KEY, or NONE) +- Column descriptions/comments + +### Relationships +- Foreign key constraints +- Referenced tables and columns +- Constraint names and metadata + +## Graph Database Schema + +The extracted schema is stored in the graph database with the following node types: + +- **Database**: Represents the source database +- **Table**: Represents database tables +- **Column**: Represents table columns + +And the following relationship types: + +- **BELONGS_TO**: Connects columns to their tables +- **REFERENCES**: Connects foreign key columns to their referenced columns + +## Error Handling + +The loader handles various error conditions: + +- **Connection Errors**: Invalid connection URLs or database unavailability +- **Permission Errors**: Insufficient database permissions +- **Schema Errors**: Invalid or corrupt schema information +- **Graph Errors**: Issues with graph database operations + +## Example Output + +{% capture shell_2 %} +Extracting table information: 100%|██████████| 15/15 [00:02<00:00, 7.50it/s] +Creating Graph Table Nodes: 100%|██████████| 15/15 [00:05<00:00, 2.80it/s] +Creating embeddings for customers columns: 100%|██████████| 2/2 [00:01<00:00, 1.20it/s] +Creating Graph Columns for customers: 100%|██████████| 8/8 [00:03<00:00, 2.40it/s] +... +Creating Graph Table Relationships: 100%|██████████| 12/12 [00:02<00:00, 5.20it/s] + +PostgreSQL schema loaded successfully. Found 15 tables. +{% endcapture %} + +{% include code_tabs.html id="output_tabs" shell=shell_2 %} + +## Requirements + +- Python 3.12+ +- psycopg2-binary +- Access to a PostgreSQL database +- Existing graph database infrastructure (FalkorDB) + +## Limitations + +- Currently only supports PostgreSQL databases +- Extracts schema from the 'public' schema only +- Requires read permissions on information_schema and pg_* system tables +- Large schemas may take time to process due to embedding generation + +## Troubleshooting + +### Common Issues + +1. **Connection Failed**: Verify the connection URL format and database credentials +2. **Permission Denied**: Ensure the database user has read access to system tables +3. **Schema Not Found**: Check that tables exist in the 'public' schema +4. **Graph Database Error**: Verify that the graph database is running and accessible + +### Debug Mode + +For debugging, you can enable verbose output by modifying the loader to print additional information about the extraction process. From 2792847da7848f12cb95b9cd47a28cf5d6070acb Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 11:40:00 +0300 Subject: [PATCH 26/58] add warning message for distructive operations --- api/index.py | 114 ++++++++++++++++++++++++++++++++++ api/static/css/chat.css | 80 ++++++++++++++++++++++++ api/static/js/chat.js | 132 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 326 insertions(+) diff --git a/api/index.py b/api/index.py index 5f7c8d35..c702f714 100644 --- a/api/index.py +++ b/api/index.py @@ -247,6 +247,52 @@ def generate(): # If the SQL query is valid, execute it using the postgress database db_url if answer_an["is_sql_translatable"]: + # Check if this is a destructive operation that requires confirmation + sql_query = answer_an["sql_query"] + sql_type = sql_query.strip().split()[0].upper() if sql_query else "" + + if sql_type in ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE']: + # This is a destructive operation - ask for user confirmation + confirmation_message = f""" +⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ + +The generated SQL query will perform a **{sql_type}** operation: + +SQL: +{sql_query} + +What this will do: +""" + if sql_type == 'INSERT': + confirmation_message += "• Add new data to the database" + elif sql_type == 'UPDATE': + confirmation_message += "• Modify existing data in the database" + elif sql_type == 'DELETE': + confirmation_message += "• **PERMANENTLY DELETE** data from the database" + elif sql_type == 'DROP': + confirmation_message += "• **PERMANENTLY DELETE** entire tables or database objects" + elif sql_type == 'CREATE': + confirmation_message += "• Create new tables or database objects" + elif sql_type == 'ALTER': + confirmation_message += "• Modify the structure of existing tables" + elif sql_type == 'TRUNCATE': + confirmation_message += "• **PERMANENTLY DELETE ALL DATA** from specified tables" + + confirmation_message += """ + +⚠️ WARNING: This operation will make changes to your database and may be irreversible. +""" + + yield json.dumps( + { + "type": "destructive_confirmation", + "message": confirmation_message, + "sql_query": sql_query, + "operation_type": sql_type + } + ) + MESSAGE_DELIMITER + return # Stop here and wait for user confirmation + try: step = {"type": "reasoning_step", "message": "Step 2: Executing SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER @@ -287,6 +333,74 @@ def generate(): return Response(stream_with_context(generate()), content_type="application/json") +@app.route("/graphs//confirm", methods=["POST"]) +@token_required # Apply token authentication decorator +def confirm_destructive_operation(graph_id: str): + """ + Handle user confirmation for destructive SQL operations + """ + graph_id = g.user_id + "_" + graph_id.strip() + request_data = request.get_json() + confirmation = request_data.get("confirmation", "").strip().upper() + sql_query = request_data.get("sql_query", "") + queries_history = request_data.get("chat", []) + + if not sql_query: + return jsonify({"error": "No SQL query provided"}), 400 + + # Create a generator function for streaming the confirmation response + def generate_confirmation(): + if confirmation == "CONFIRM": + try: + db_description, db_url = get_db_description(graph_id) + + step = {"type": "reasoning_step", "message": "Step 1: Executing confirmed SQL query"} + yield json.dumps(step) + MESSAGE_DELIMITER + + query_results = PostgresLoader.execute_sql_query(sql_query, db_url) + yield json.dumps( + { + "type": "query_result", + "data": query_results, + } + ) + MESSAGE_DELIMITER + + # Generate user-readable response using AI + step = {"type": "reasoning_step", "message": "Step 2: Generating user-friendly response"} + yield json.dumps(step) + MESSAGE_DELIMITER + + response_agent = ResponseFormatterAgent() + user_readable_response = response_agent.format_response( + user_query=queries_history[-1] if queries_history else "Destructive operation", + sql_query=sql_query, + query_results=query_results, + db_description=db_description + ) + + yield json.dumps( + { + "type": "ai_response", + "message": user_readable_response, + } + ) + MESSAGE_DELIMITER + + except Exception as e: + logging.error("Error executing confirmed SQL query: %s", e) + yield json.dumps( + {"type": "error", "message": f"Error executing query: {str(e)}"} + ) + MESSAGE_DELIMITER + else: + # User cancelled or provided invalid confirmation + yield json.dumps( + { + "type": "operation_cancelled", + "message": "Operation cancelled. The destructive SQL query was not executed." + } + ) + MESSAGE_DELIMITER + + return Response(stream_with_context(generate_confirmation()), content_type="application/json") + + @app.route("/suggestions") @token_required # Apply token authentication decorator def suggestions(): diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 9410ef27..e9d243ff 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -834,4 +834,84 @@ body { .logout-btn:hover { background: #c0392b; +} + +/* Destructive Confirmation Styles */ +.destructive-confirmation-container { + border: 2px solid #ff4444; + border-radius: 8px; + background: linear-gradient(135deg, #2a1f1f, #3a2222); + margin: 10px 0; + transition: all 0.3s ease; +} + +.destructive-confirmation-message { + background: none !important; + border: none !important; +} + +.destructive-confirmation { + padding: 20px; + transition: all 0.3s ease; +} + +.confirmation-text { + margin-bottom: 20px; + line-height: 1.6; + color: #ffffff; + font-size: 14px; +} + +.confirmation-text strong { + color: #ff6666; +} + +.confirmation-buttons { + display: flex; + gap: 15px; + justify-content: center; + margin-top: 20px; +} + +.confirm-btn, .cancel-btn { + padding: 12px 24px; + border: none; + border-radius: 6px; + font-family: 'Fira Code', monospace; + font-size: 14px; + font-weight: bold; + cursor: pointer; + transition: all 0.3s ease; + min-width: 140px; +} + +.confirm-btn { + background: #ff4444; + color: white; + border: 2px solid #ff4444; +} + +.confirm-btn:hover { + background: #ff2222; + border-color: #ff2222; + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(255, 68, 68, 0.4); +} + +.cancel-btn { + background: transparent; + color: #ffffff; + border: 2px solid #666666; +} + +.cancel-btn:hover { + background: #666666; + border-color: #888888; + transform: translateY(-2px); + box-shadow: 0 4px 12px rgba(102, 102, 102, 0.4); +} + +.confirm-btn:active, .cancel-btn:active { + transform: translateY(0); + box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3); } \ No newline at end of file diff --git a/api/static/js/chat.js b/api/static/js/chat.js index d2bc7c4e..856c0010 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -320,6 +320,12 @@ async function sendMessage() { } else if (step.type === 'ai_response') { // Handle AI-generated user-friendly response addMessage(step.message, false, false, true); + } else if (step.type === 'destructive_confirmation') { + // Handle destructive operation confirmation request + addDestructiveConfirmationMessage(step); + } else if (step.type === 'operation_cancelled') { + // Handle cancelled operation + addMessage(step.message, false, true); } else { // Default handling addMessage(step.message || JSON.stringify(step), false); @@ -394,6 +400,132 @@ function pauseRequest() { } } +function addDestructiveConfirmationMessage(step) { + const messageDiv = document.createElement('div'); + const messageDivContainer = document.createElement('div'); + + messageDivContainer.className = "message-container bot-message-container destructive-confirmation-container"; + messageDiv.className = "message bot-message destructive-confirmation-message"; + + // Create the confirmation UI + const confirmationHTML = ` +
+
${step.message.replace(/\n/g, '
')}
+
+ + +
+
+ `; + + messageDiv.innerHTML = confirmationHTML; + + messageDivContainer.appendChild(messageDiv); + chatMessages.appendChild(messageDivContainer); + chatMessages.scrollTop = chatMessages.scrollHeight; + + // Disable the main input while waiting for confirmation + messageInput.disabled = true; + submitButton.disabled = true; +} + +async function handleDestructiveConfirmation(confirmation, sqlQuery) { + // Re-enable the input + messageInput.disabled = false; + submitButton.disabled = false; + + // Add user's choice as a message + addMessage(`User choice: ${confirmation}`, true); + + if (confirmation === 'CANCEL') { + addMessage("Operation cancelled. The destructive SQL query was not executed.", false, true); + return; + } + + // If confirmed, send confirmation to server + try { + const selectedValue = document.getElementById("graph-select").value; + + const response = await fetch('/graphs/' + selectedValue + '/confirm', { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + confirmation: confirmation, + sql_query: sqlQuery, + chat: questions_history + }) + }); + + if (!response.ok) { + throw new Error(`Server responded with ${response.status}`); + } + + // Process the streaming response + const reader = response.body.getReader(); + let decoder = new TextDecoder(); + let buffer = ''; + + while (true) { + const { done, value } = await reader.read(); + + if (done) { + if (buffer.trim()) { + try { + const step = JSON.parse(buffer); + addMessage(step.message || JSON.stringify(step), false); + } catch (e) { + addMessage(buffer, false); + } + } + break; + } + + const chunk = decoder.decode(value, { stream: true }); + buffer += chunk; + + let delimiterIndex; + while ((delimiterIndex = buffer.indexOf(MESSAGE_DELIMITER)) !== -1) { + const message = buffer.slice(0, delimiterIndex).trim(); + buffer = buffer.slice(delimiterIndex + MESSAGE_DELIMITER.length); + + if (!message) continue; + + try { + const step = JSON.parse(message); + + if (step.type === 'reasoning_step') { + addMessage(step.message, false); + } else if (step.type === 'query_result') { + if (step.data) { + addMessage(`Query Result: ${JSON.stringify(step.data)}`, false, false, true); + } else { + addMessage("No results found for the query.", false); + } + } else if (step.type === 'ai_response') { + addMessage(step.message, false, false, true); + } else if (step.type === 'error') { + addMessage(`Error: ${step.message}`, false, true); + } else { + addMessage(step.message || JSON.stringify(step), false); + } + } catch (e) { + addMessage("Failed: " + message, false); + } + } + } + + } catch (error) { + console.error('Error:', error); + addMessage('Sorry, there was an error processing the confirmation: ' + error.message, false); + } +} + // Event listeners submitButton.addEventListener('click', sendMessage); pauseButton.addEventListener('click', pauseRequest); From 4316c2efbd6dc8c2bea52d88a632e169059fad9b Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 12:03:04 +0300 Subject: [PATCH 27/58] refactor agents file to a folder --- api/agents.py | 564 ------------------------- api/agents/README.md | 67 +++ api/agents/__init__.py | 17 + api/agents/analysis_agent.py | 209 +++++++++ api/agents/follow_up_agent.py | 72 ++++ api/agents/relevancy_agent.py | 89 ++++ api/agents/response_formatter_agent.py | 133 ++++++ api/agents/taxonomy_agent.py | 59 +++ api/agents/utils.py | 33 ++ api/index.py | 3 +- 10 files changed, 680 insertions(+), 566 deletions(-) delete mode 100644 api/agents.py create mode 100644 api/agents/README.md create mode 100644 api/agents/__init__.py create mode 100644 api/agents/analysis_agent.py create mode 100644 api/agents/follow_up_agent.py create mode 100644 api/agents/relevancy_agent.py create mode 100644 api/agents/response_formatter_agent.py create mode 100644 api/agents/taxonomy_agent.py create mode 100644 api/agents/utils.py diff --git a/api/agents.py b/api/agents.py deleted file mode 100644 index caae916a..00000000 --- a/api/agents.py +++ /dev/null @@ -1,564 +0,0 @@ -"""Module containing agent classes for handling analysis and SQL generation tasks.""" - -import json -from typing import Any, Dict, List - -from litellm import completion - -from api.config import Config - - -class AnalysisAgent: - """Agent for analyzing user queries and generating database analysis.""" - - def __init__(self, queries_history: list, result_history: list): - """Initialize the analysis agent with query and result history.""" - if result_history is None: - self.messages = [] - else: - self.messages = [] - for query, result in zip(queries_history[:-1], result_history): - self.messages.append({"role": "user", "content": query}) - self.messages.append({"role": "assistant", "content": result}) - - def get_analysis( - self, - user_query: str, - combined_tables: list, - db_description: str, - instructions: str = None, - ) -> dict: - """Get analysis of user query against database schema.""" - formatted_schema = self._format_schema(combined_tables) - prompt = self._build_prompt( - user_query, formatted_schema, db_description, instructions - ) - self.messages.append({"role": "user", "content": prompt}) - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=self.messages, - temperature=0, - top_p=1, - ) - - response = completion_result.choices[0].message.content - analysis = _parse_response(response) - if isinstance(analysis["ambiguities"], list): - analysis["ambiguities"] = [ - item.replace("-", " ") for item in analysis["ambiguities"] - ] - analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) - if isinstance(analysis["missing_information"], list): - analysis["missing_information"] = [ - item.replace("-", " ") for item in analysis["missing_information"] - ] - analysis["missing_information"] = "- " + "- ".join( - analysis["missing_information"] - ) - self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) - return analysis - - def _format_schema(self, schema_data: List) -> str: - """ - Format the schema data into a readable format for the prompt. - - Args: - schema_data: Schema in the structure [...] - - Returns: - Formatted schema as a string - """ - formatted_schema = [] - - for table_info in schema_data: - table_name = table_info[0] - table_description = table_info[1] - foreign_keys = table_info[2] - columns = table_info[3] - - # Format table header - table_str = f"Table: {table_name} - {table_description}\n" - - # Format columns using the updated OrderedDict structure - for column in columns: - col_name = column.get("columnName", "") - col_type = column.get("dataType", None) - col_description = column.get("description", "") - col_key = column.get("keyType", None) - nullable = column.get("nullable", False) - - key_info = ( - ", PRIMARY KEY" - if col_key == "PRI" - else ", FOREIGN KEY" if col_key == "FK" else "" - ) - column_str = (f" - {col_name} ({col_type},{key_info},{col_key}," - f"{nullable}): {col_description}") - table_str += column_str + "\n" - - # Format foreign keys - if isinstance(foreign_keys, dict) and foreign_keys: - table_str += " Foreign Keys:\n" - for fk_name, fk_info in foreign_keys.items(): - column = fk_info.get("column", "") - ref_table = fk_info.get("referenced_table", "") - ref_column = fk_info.get("referenced_column", "") - table_str += ( - f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" - ) - - formatted_schema.append(table_str) - - return "\n".join(formatted_schema) - - def _build_prompt( - self, user_input: str, formatted_schema: str, db_description: str, instructions - ) -> str: - """ - Build the prompt for Claude to analyze the query. - - Args: - user_input: The natural language query from the user - formatted_schema: Formatted database schema - - Returns: - The formatted prompt for Claude - """ - prompt = f""" - You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score. - - MANDATORY RULES: - - Always explain if you cannot fully follow the instructions. - - Always reduce the confidence score if instructions cannot be fully applied. - - Never skip explaining missing information, ambiguities, or instruction issues. - - Respond ONLY in strict JSON format, without extra text. - - If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far. - - If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question. - - Your output JSON MUST contain all fields, even if empty (e.g., "missing_information": []). - - --- - - Now analyze the user query based on the provided inputs: - - - {db_description} - - - - {instructions} - - - - {formatted_schema} - - - - {self.messages} - - - - {user_input} - - - --- - - Your task: - - - Analyze the query's translatability into SQL according to the instructions. - - Apply the instructions explicitly. - - If you CANNOT apply instructions in the SQL, explain why under - "instructions_comments", "explanation" and reduce your confidence. - - Penalize confidence appropriately if any part of the instructions is unmet. - - When there several tables that can be used to answer the question, - you can combine them in a single SQL query. - - Provide your output ONLY in the following JSON structure: - - ```json - {{ - "is_sql_translatable": true or false, - "instructions_comments": ("Comments about any part of the instructions, " - "especially if they are unclear, impossible, " - "or partially met"), - "explanation": ("Detailed explanation why the query can or cannot be " - "translated, mentioning instructions explicitly and " - "referencing conversation history if relevant"), - "sql_query": ("High-level SQL query (you must to applying instructions " - "and use previous answers if the question is a continuation)"), - "tables_used": ["list", "of", "tables", "used", "in", "the", "query", - "with", "the", "relationships", "between", "them"], - "missing_information": ["list", "of", "missing", "information"], - "ambiguities": ["list", "of", "ambiguities"], - "confidence": integer between 0 and 100 - }} - - Evaluation Guidelines: - - 1. Verify if all requested information exists in the schema. - 2. Check if the query's intent is clear enough for SQL translation. - 3. Identify any ambiguities in the query or instructions. - 4. List missing information explicitly if applicable. - 5. Confirm if necessary joins are possible. - 6. Consider if complex calculations are feasible in SQL. - 7. Identify multiple interpretations if they exist. - 8. Strictly apply instructions; explain and penalize if not possible. - 9. If the question is a follow-up, resolve references using the - conversation history and previous answers. - - Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block. """ - return prompt - - -class RelevancyAgent: - """Agent for determining relevancy of queries to database schema.""" - - def __init__(self, queries_history: list, result_history: list): - """Initialize the relevancy agent with query and result history.""" - if result_history is None: - self.messages = [] - else: - self.messages = [] - for query, result in zip(queries_history[:-1], result_history): - self.messages.append({"role": "user", "content": query}) - self.messages.append({"role": "assistant", "content": result}) - - def get_answer(self, user_question: str, database_desc: dict) -> dict: - """Get relevancy assessment for user question against database description.""" - self.messages.append( - { - "role": "user", - "content": RELEVANCY_PROMPT.format( - QUESTION_PLACEHOLDER=user_question, - DB_PLACEHOLDER=json.dumps(database_desc), - ), - } - ) - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=self.messages, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - self.messages.append({"role": "assistant", "content": answer}) - return _parse_response(answer) - - -RELEVANCY_PROMPT = """ -You are an expert assistant tasked with determining whether the user’s question aligns with a given database description and whether the question is appropriate. You receive two inputs: - -The user’s question: {QUESTION_PLACEHOLDER} -The database description: {DB_PLACEHOLDER} -Please follow these instructions: - -Understand the question in the context of the database. -• Ask yourself: “Does this question relate to the data or concepts described in the database description?” -• Common tables that can be found in most of the systems considered "On-topic" even if it not explict in the database description. -• Don't answer questions that related to yourself. -• Don't answer questions that related to personal information unless it related to data in the schemas. -• Questions about the user's (first person) defined as "personal" and is Off-topic. -• Questions about yourself defined as "personal" and is Off-topic. - -Determine if the question is: -• On-topic and appropriate: -– If so, provide a JSON response in the following format: -{{ -"status": "On-topic", -"reason": "Brief explanation of why it is on-topic and appropriate." -"suggestions": [] -}} - -• Off-topic: -– If the question does not align with the data or use cases implied by the schema, provide a JSON response: -{{ -"status": "Off-topic", -"reason": "Short reason explaining why it is off-topic.", -"suggestions": [ -"An alternative, high-level question about the schema..." -] -}} - -• Inappropriate: -– If the question is offensive, illegal, or otherwise violates content guidelines, provide a JSON response: -{{ -"status": "Inappropriate", -"reason": "Short reason why it is inappropriate.", -"suggestions": [ -"Suggested topics that would be more appropriate..." -] -}} - -Ensure your response is concise, polite, and helpful. -""" - - -class FollowUpAgent: - """Agent for handling follow-up questions and conversational context.""" - - def __init__(self): - """Initialize the follow-up agent.""" - - def get_answer( - self, user_question: str, conversation_hist: list, database_schema: dict - ) -> dict: - """Get answer for follow-up questions using conversation history.""" - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=[ - { - "content": FOLLOW_UP_PROMPT.format( - QUESTION=user_question, - HISTORY=conversation_hist, - SCHEMA=json.dumps(database_schema), - ), - "role": "user", - } - ], - response_format={"type": "json_object"}, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - return json.loads(answer) - - -FOLLOW_UP_PROMPT = """You are an expert assistant that receives two inputs: - -1. The user’s question: {QUESTION} -2. The history of his questions: {HISTORY} -3. A detected database schema (all relevant tables, columns, and their descriptions): {SCHEMA} - -Your primary goal is to decide if the user’s questions can be addressed using the existing schema or if new or additional data is required. -Any thing that can be calculated from the provided tables is define the status Data-focused. -Please follow these steps: - -1. Understand the user’s question in the context of the provided schema. -• Determine whether the question directly relates to the tables, columns, or concepts in the schema or needed more information about the filtering. - -2. If the question relates to the existing schema: -• Provide a concise JSON response indicating: -{{ -"status": "Data-focused", -"reason": "Brief explanation why this question is answerable with the given schema." -"followUpQuestion": "" -}} -• If relevant, note any additional observations or suggested follow-up. - -3. If the question cannot be answered solely with the given schema or if there seems to be missing context: -• Ask clarifying questions to confirm the user’s intent or to gather any necessary information. -• Use a JSON format such as: -{{ -"status": "Needs more data", -"reason": "Reason why the current schema is insufficient.", -"followUpQuestion": "Single question to clarify user intent or additional data needed, can be a specific value..." - -}} - -4. Ensure your response is concise, polite, and helpful. When asking clarifying - questions, be specific and guide the user toward providing the missing details - so you can effectively address their query.""" - - -class TaxonomyAgent: - """Agent for taxonomy classification of questions and SQL queries.""" - - def __init__(self): - """Initialize the taxonomy agent.""" - - def get_answer(self, question: str, sql: str) -> str: - """Get taxonomy classification for a question and SQL pair.""" - messages = [ - { - "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql), - "role": "user", - } - ] - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=messages, - temperature=0, - ) - - answer = completion_result.choices[0].message.content - return answer - - -TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query \ -provde a single clarification question to the user. -* For any SQL query that contain WHERE clause, provide a clarification question to the user about the \ -generated value. -* Your question can contain more than one clarification related to WHERE clause. -* Please asked only about the clarifications that you need and not extand the answer. -* Please ask in a polite, humen, and concise manner. -* Do not meantion any tables or columns in your ouput!. -* If you dont need any clarification, please answer with "I don't need any clarification." -* The user didnt saw the SQL queryor the tables, so please understand this position and ask the \ -clarification in that way he have the relevent information to answer. -* When you ask the user to confirm a value, please provide the value in your answer. -* Mention only question about values and dont mention the SQL query or the tables in your answer. - -Please create the clarification question step by step. - -Question: -{QUESTION} - -SQL: -{SQL} - -For example: -question: "How many diabetic patients are there?" -SQL: "SELECT COUNT(*) FROM patients WHERE disease_code = 'E11'" -Your output: "The diabitic desease code is E11? If not, please provide the correct diabitic desease code. - -The question to the user:" -""" - - -def _parse_response(response: str) -> Dict[str, Any]: - """ - Parse Claude's response to extract the analysis. - - Args: - response: Claude's response string - - Returns: - Parsed analysis results - """ - try: - # Extract JSON from the response - json_start = response.find("{") - json_end = response.rfind("}") + 1 - json_str = response[json_start:json_end] - - # Parse the JSON - analysis = json.loads(json_str) - return analysis - except (json.JSONDecodeError, ValueError) as e: - # Fallback if JSON parsing fails - return { - "is_sql_translatable": False, - "confidence": 0, - "explanation": f"Failed to parse response: {str(e)}", - "error": str(response), - } - - -class ResponseFormatterAgent: - """Agent for generating user-readable responses from SQL query results.""" - - def __init__(self): - """Initialize the response formatter agent.""" - pass - - def format_response(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str = "") -> str: - """ - Generate a user-readable response based on the SQL query results. - - Args: - user_query: The original user question - sql_query: The SQL query that was executed - query_results: The results from the SQL query execution - db_description: Description of the database context - - Returns: - A formatted, user-readable response string - """ - prompt = self._build_response_prompt(user_query, sql_query, query_results, db_description) - - messages = [{"role": "user", "content": prompt}] - - completion_result = completion( - model=Config.COMPLETION_MODEL, - messages=messages, - temperature=0.3, # Slightly higher temperature for more natural responses - top_p=1, - ) - - response = completion_result.choices[0].message.content - return response.strip() - - def _build_response_prompt(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str) -> str: - """Build the prompt for generating user-readable responses.""" - - # Format the query results for better readability - formatted_results = self._format_query_results(query_results) - - # Determine the type of SQL operation - sql_type = sql_query.strip().split()[0].upper() if sql_query else "UNKNOWN" - - prompt = f""" -You are an AI assistant that helps users understand database query results. Your task is to analyze the SQL query results and provide a clear, concise, and user-friendly explanation. - -**Context:** -Database Description: {db_description if db_description else "Not provided"} - -**User's Original Question:** -{user_query} - -**SQL Query Executed:** -{sql_query} - -**Query Type:** {sql_type} - -**Query Results:** -{formatted_results} - -**Instructions:** -1. Provide a clear, natural language answer to the user's question based on the query results -2. For SELECT queries: Focus on the key insights and findings from the data -3. For INSERT/UPDATE/DELETE queries: Confirm the operation was successful and mention how many records were affected -4. For other operations (CREATE, DROP, etc.): Confirm the operation was completed successfully -5. Use bullet points or numbered lists when presenting multiple items -6. Include relevant numbers, percentages, or trends if applicable -7. Be concise but comprehensive - avoid unnecessary technical jargon -8. If the results are empty, explain that no data was found matching the criteria -9. If there are many results, provide a summary with highlights -10. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding - -**Response Format:** -Provide a direct answer to the user's question in a conversational tone, as if you were explaining the findings to a colleague. -""" - - return prompt - - def _format_query_results(self, query_results: List[Dict]) -> str: - """Format query results for inclusion in the prompt.""" - if not query_results: - return "No results found." - - if len(query_results) == 0: - return "No results found." - - # Check if this is an operation result (INSERT/UPDATE/DELETE) - if len(query_results) == 1 and "operation" in query_results[0]: - result = query_results[0] - operation = result.get("operation", "UNKNOWN") - affected_rows = result.get("affected_rows") - status = result.get("status", "unknown") - - if affected_rows is not None: - return f"Operation: {operation}, Status: {status}, Affected rows: {affected_rows}" - else: - return f"Operation: {operation}, Status: {status}" - - # Handle regular SELECT query results - # Limit the number of results shown in the prompt to avoid token limits - max_results_to_show = 50 - results_to_show = query_results[:max_results_to_show] - - formatted = [] - for i, result in enumerate(results_to_show, 1): - if isinstance(result, dict): - result_str = ", ".join([f"{k}: {v}" for k, v in result.items()]) - formatted.append(f"{i}. {result_str}") - else: - formatted.append(f"{i}. {result}") - - result_text = "\n".join(formatted) - - if len(query_results) > max_results_to_show: - result_text += f"\n... and {len(query_results) - max_results_to_show} more results" - - return result_text diff --git a/api/agents/README.md b/api/agents/README.md new file mode 100644 index 00000000..75f66bb7 --- /dev/null +++ b/api/agents/README.md @@ -0,0 +1,67 @@ +# Agents Module + +This module contains various AI agents for the text2sql application. Each agent is responsible for a specific task in the query processing pipeline. + +## Agents + +### AnalysisAgent (`analysis_agent.py`) +- **Purpose**: Analyzes user queries and generates database analysis +- **Key Method**: `get_analysis()` - Analyzes user queries against database schema +- **Features**: Schema formatting, prompt building, conversation history tracking + +### RelevancyAgent (`relevancy_agent.py`) +- **Purpose**: Determines if queries are relevant to the database schema +- **Key Method**: `get_answer()` - Assesses query relevancy against database description +- **Features**: Topic classification (On-topic, Off-topic, Inappropriate) + +### FollowUpAgent (`follow_up_agent.py`) +- **Purpose**: Handles follow-up questions and conversational context +- **Key Method**: `get_answer()` - Processes follow-up questions using conversation history +- **Features**: Context awareness, data availability assessment + +### TaxonomyAgent (`taxonomy_agent.py`) +- **Purpose**: Provides taxonomy classification and clarification for SQL queries +- **Key Method**: `get_answer()` - Generates clarification questions for SQL queries +- **Features**: WHERE clause analysis, user-friendly clarifications + +### ResponseFormatterAgent (`response_formatter_agent.py`) +- **Purpose**: Formats SQL query results into user-readable responses +- **Key Method**: `format_response()` - Converts raw SQL results to natural language +- **Features**: Result formatting, operation type detection, user-friendly explanations + +## Utilities + +### utils.py +- **parse_response()**: Shared utility function for parsing JSON responses from AI models +- Used by multiple agents for consistent response parsing + +## Usage + +```python +from api.agents import AnalysisAgent, RelevancyAgent, ResponseFormatterAgent + +# Initialize agents +analysis_agent = AnalysisAgent(queries_history, result_history) +relevancy_agent = RelevancyAgent(queries_history, result_history) +formatter_agent = ResponseFormatterAgent() + +# Use agents +analysis = analysis_agent.get_analysis(query, tables, db_description) +relevancy = relevancy_agent.get_answer(question, database_desc) +response = formatter_agent.format_response(query, sql, results, db_description) +``` + +## Architecture + +Each agent follows a consistent pattern: +1. **Initialization**: Set up with necessary context (history, configuration) +2. **Main Method**: Primary interface for the agent's functionality +3. **Helper Methods**: Private methods for internal processing +4. **Prompt Templates**: Stored as module-level constants for easy maintenance +5. **LLM Integration**: Uses litellm for AI model interactions + +This modular structure improves: +- **Maintainability**: Each agent is self-contained +- **Testability**: Agents can be tested independently +- **Reusability**: Agents can be used in different contexts +- **Scalability**: New agents can be added without affecting existing ones diff --git a/api/agents/__init__.py b/api/agents/__init__.py new file mode 100644 index 00000000..1d508cc4 --- /dev/null +++ b/api/agents/__init__.py @@ -0,0 +1,17 @@ +"""Agents package for text2sql application.""" + +from .analysis_agent import AnalysisAgent +from .relevancy_agent import RelevancyAgent +from .follow_up_agent import FollowUpAgent +from .taxonomy_agent import TaxonomyAgent +from .response_formatter_agent import ResponseFormatterAgent +from .utils import parse_response + +__all__ = [ + "AnalysisAgent", + "RelevancyAgent", + "FollowUpAgent", + "TaxonomyAgent", + "ResponseFormatterAgent", + "parse_response" +] diff --git a/api/agents/analysis_agent.py b/api/agents/analysis_agent.py new file mode 100644 index 00000000..52494d66 --- /dev/null +++ b/api/agents/analysis_agent.py @@ -0,0 +1,209 @@ +"""Analysis agent for analyzing user queries and generating database analysis.""" + +from typing import List +from litellm import completion +from api.config import Config +from .utils import parse_response + + +class AnalysisAgent: + """Agent for analyzing user queries and generating database analysis.""" + + def __init__(self, queries_history: list, result_history: list): + """Initialize the analysis agent with query and result history.""" + if result_history is None: + self.messages = [] + else: + self.messages = [] + for query, result in zip(queries_history[:-1], result_history): + self.messages.append({"role": "user", "content": query}) + self.messages.append({"role": "assistant", "content": result}) + + def get_analysis( + self, + user_query: str, + combined_tables: list, + db_description: str, + instructions: str = None, + ) -> dict: + """Get analysis of user query against database schema.""" + formatted_schema = self._format_schema(combined_tables) + prompt = self._build_prompt( + user_query, formatted_schema, db_description, instructions + ) + self.messages.append({"role": "user", "content": prompt}) + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0, + top_p=1, + ) + + response = completion_result.choices[0].message.content + analysis = parse_response(response) + if isinstance(analysis["ambiguities"], list): + analysis["ambiguities"] = [ + item.replace("-", " ") for item in analysis["ambiguities"] + ] + analysis["ambiguities"] = "- " + "- ".join(analysis["ambiguities"]) + if isinstance(analysis["missing_information"], list): + analysis["missing_information"] = [ + item.replace("-", " ") for item in analysis["missing_information"] + ] + analysis["missing_information"] = "- " + "- ".join( + analysis["missing_information"] + ) + self.messages.append({"role": "assistant", "content": analysis["sql_query"]}) + return analysis + + def _format_schema(self, schema_data: List) -> str: + """ + Format the schema data into a readable format for the prompt. + + Args: + schema_data: Schema in the structure [...] + + Returns: + Formatted schema as a string + """ + formatted_schema = [] + + for table_info in schema_data: + table_name = table_info[0] + table_description = table_info[1] + foreign_keys = table_info[2] + columns = table_info[3] + + # Format table header + table_str = f"Table: {table_name} - {table_description}\n" + + # Format columns using the updated OrderedDict structure + for column in columns: + col_name = column.get("columnName", "") + col_type = column.get("dataType", None) + col_description = column.get("description", "") + col_key = column.get("keyType", None) + nullable = column.get("nullable", False) + + key_info = ( + ", PRIMARY KEY" + if col_key == "PRI" + else ", FOREIGN KEY" if col_key == "FK" else "" + ) + column_str = (f" - {col_name} ({col_type},{key_info},{col_key}," + f"{nullable}): {col_description}") + table_str += column_str + "\n" + + # Format foreign keys + if isinstance(foreign_keys, dict) and foreign_keys: + table_str += " Foreign Keys:\n" + for fk_name, fk_info in foreign_keys.items(): + column = fk_info.get("column", "") + ref_table = fk_info.get("referenced_table", "") + ref_column = fk_info.get("referenced_column", "") + table_str += ( + f" - {fk_name}: {column} references {ref_table}.{ref_column}\n" + ) + + formatted_schema.append(table_str) + + return "\n".join(formatted_schema) + + def _build_prompt( + self, user_input: str, formatted_schema: str, db_description: str, instructions + ) -> str: + """ + Build the prompt for Claude to analyze the query. + + Args: + user_input: The natural language query from the user + formatted_schema: Formatted database schema + + Returns: + The formatted prompt for Claude + """ + prompt = f""" + You must strictly follow the instructions below. Deviations will result in a penalty to your confidence score. + + MANDATORY RULES: + - Always explain if you cannot fully follow the instructions. + - Always reduce the confidence score if instructions cannot be fully applied. + - Never skip explaining missing information, ambiguities, or instruction issues. + - Respond ONLY in strict JSON format, without extra text. + - If the query relates to a previous question, you MUST take into account the previous question and its answer, and answer based on the context and information provided so far. + + If the user is asking a follow-up or continuing question, use the conversation history and previous answers to resolve references, context, or ambiguities. Always base your analysis on the cumulative context, not just the current question. + + Your output JSON MUST contain all fields, even if empty (e.g., "missing_information": []). + + --- + + Now analyze the user query based on the provided inputs: + + + {db_description} + + + + {instructions} + + + + {formatted_schema} + + + + {self.messages} + + + + {user_input} + + + --- + + Your task: + + - Analyze the query's translatability into SQL according to the instructions. + - Apply the instructions explicitly. + - If you CANNOT apply instructions in the SQL, explain why under + "instructions_comments", "explanation" and reduce your confidence. + - Penalize confidence appropriately if any part of the instructions is unmet. + - When there several tables that can be used to answer the question, + you can combine them in a single SQL query. + + Provide your output ONLY in the following JSON structure: + + ```json + {{ + "is_sql_translatable": true or false, + "instructions_comments": ("Comments about any part of the instructions, " + "especially if they are unclear, impossible, " + "or partially met"), + "explanation": ("Detailed explanation why the query can or cannot be " + "translated, mentioning instructions explicitly and " + "referencing conversation history if relevant"), + "sql_query": ("High-level SQL query (you must to applying instructions " + "and use previous answers if the question is a continuation)"), + "tables_used": ["list", "of", "tables", "used", "in", "the", "query", + "with", "the", "relationships", "between", "them"], + "missing_information": ["list", "of", "missing", "information"], + "ambiguities": ["list", "of", "ambiguities"], + "confidence": integer between 0 and 100 + }} + + Evaluation Guidelines: + + 1. Verify if all requested information exists in the schema. + 2. Check if the query's intent is clear enough for SQL translation. + 3. Identify any ambiguities in the query or instructions. + 4. List missing information explicitly if applicable. + 5. Confirm if necessary joins are possible. + 6. Consider if complex calculations are feasible in SQL. + 7. Identify multiple interpretations if they exist. + 8. Strictly apply instructions; explain and penalize if not possible. + 9. If the question is a follow-up, resolve references using the + conversation history and previous answers. + + Again: OUTPUT ONLY VALID JSON. No explanations outside the JSON block. """ + return prompt diff --git a/api/agents/follow_up_agent.py b/api/agents/follow_up_agent.py new file mode 100644 index 00000000..21f40e90 --- /dev/null +++ b/api/agents/follow_up_agent.py @@ -0,0 +1,72 @@ +"""Follow-up agent for handling follow-up questions and conversational context.""" + +import json +from litellm import completion +from api.config import Config + + +FOLLOW_UP_PROMPT = """You are an expert assistant that receives two inputs: + +1. The user's question: {QUESTION} +2. The history of his questions: {HISTORY} +3. A detected database schema (all relevant tables, columns, and their descriptions): {SCHEMA} + +Your primary goal is to decide if the user's questions can be addressed using the existing schema or if new or additional data is required. +Any thing that can be calculated from the provided tables is define the status Data-focused. +Please follow these steps: + +1. Understand the user's question in the context of the provided schema. +• Determine whether the question directly relates to the tables, columns, or concepts in the schema or needed more information about the filtering. + +2. If the question relates to the existing schema: +• Provide a concise JSON response indicating: +{{ +"status": "Data-focused", +"reason": "Brief explanation why this question is answerable with the given schema." +"followUpQuestion": "" +}} +• If relevant, note any additional observations or suggested follow-up. + +3. If the question cannot be answered solely with the given schema or if there seems to be missing context: +• Ask clarifying questions to confirm the user's intent or to gather any necessary information. +• Use a JSON format such as: +{{ +"status": "Needs more data", +"reason": "Reason why the current schema is insufficient.", +"followUpQuestion": "Single question to clarify user intent or additional data needed, can be a specific value..." + +}} + +4. Ensure your response is concise, polite, and helpful. When asking clarifying + questions, be specific and guide the user toward providing the missing details + so you can effectively address their query.""" + + +class FollowUpAgent: + """Agent for handling follow-up questions and conversational context.""" + + def __init__(self): + """Initialize the follow-up agent.""" + + def get_answer( + self, user_question: str, conversation_hist: list, database_schema: dict + ) -> dict: + """Get answer for follow-up questions using conversation history.""" + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=[ + { + "content": FOLLOW_UP_PROMPT.format( + QUESTION=user_question, + HISTORY=conversation_hist, + SCHEMA=json.dumps(database_schema), + ), + "role": "user", + } + ], + response_format={"type": "json_object"}, + temperature=0, + ) + + answer = completion_result.choices[0].message.content + return json.loads(answer) diff --git a/api/agents/relevancy_agent.py b/api/agents/relevancy_agent.py new file mode 100644 index 00000000..931d8ee4 --- /dev/null +++ b/api/agents/relevancy_agent.py @@ -0,0 +1,89 @@ +"""Relevancy agent for determining relevancy of queries to database schema.""" + +import json +from litellm import completion +from api.config import Config +from .utils import parse_response + + +RELEVANCY_PROMPT = """ +You are an expert assistant tasked with determining whether the user's question aligns with a given database description and whether the question is appropriate. You receive two inputs: + +The user's question: {QUESTION_PLACEHOLDER} +The database description: {DB_PLACEHOLDER} +Please follow these instructions: + +Understand the question in the context of the database. +• Ask yourself: "Does this question relate to the data or concepts described in the database description?" +• Common tables that can be found in most of the systems considered "On-topic" even if it not explict in the database description. +• Don't answer questions that related to yourself. +• Don't answer questions that related to personal information unless it related to data in the schemas. +• Questions about the user's (first person) defined as "personal" and is Off-topic. +• Questions about yourself defined as "personal" and is Off-topic. + +Determine if the question is: +• On-topic and appropriate: +– If so, provide a JSON response in the following format: +{{ +"status": "On-topic", +"reason": "Brief explanation of why it is on-topic and appropriate." +"suggestions": [] +}} + +• Off-topic: +– If the question does not align with the data or use cases implied by the schema, provide a JSON response: +{{ +"status": "Off-topic", +"reason": "Short reason explaining why it is off-topic.", +"suggestions": [ +"An alternative, high-level question about the schema..." +] +}} + +• Inappropriate: +– If the question is offensive, illegal, or otherwise violates content guidelines, provide a JSON response: +{{ +"status": "Inappropriate", +"reason": "Short reason why it is inappropriate.", +"suggestions": [ +"Suggested topics that would be more appropriate..." +] +}} + +Ensure your response is concise, polite, and helpful. +""" + + +class RelevancyAgent: + """Agent for determining relevancy of queries to database schema.""" + + def __init__(self, queries_history: list, result_history: list): + """Initialize the relevancy agent with query and result history.""" + if result_history is None: + self.messages = [] + else: + self.messages = [] + for query, result in zip(queries_history[:-1], result_history): + self.messages.append({"role": "user", "content": query}) + self.messages.append({"role": "assistant", "content": result}) + + def get_answer(self, user_question: str, database_desc: dict) -> dict: + """Get relevancy assessment for user question against database description.""" + self.messages.append( + { + "role": "user", + "content": RELEVANCY_PROMPT.format( + QUESTION_PLACEHOLDER=user_question, + DB_PLACEHOLDER=json.dumps(database_desc), + ), + } + ) + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=self.messages, + temperature=0, + ) + + answer = completion_result.choices[0].message.content + self.messages.append({"role": "assistant", "content": answer}) + return parse_response(answer) diff --git a/api/agents/response_formatter_agent.py b/api/agents/response_formatter_agent.py new file mode 100644 index 00000000..413a1558 --- /dev/null +++ b/api/agents/response_formatter_agent.py @@ -0,0 +1,133 @@ +"""Response formatter agent for generating user-readable responses from SQL query results.""" + +from typing import List, Dict +from litellm import completion +from api.config import Config + + +RESPONSE_FORMATTER_PROMPT = """ +You are an AI assistant that helps users understand database query results. Your task is to analyze the SQL query results and provide a clear, concise, and user-friendly explanation. + +**Context:** +Database Description: {DB_DESCRIPTION} + +**User's Original Question:** +{USER_QUERY} + +**SQL Query Executed:** +{SQL_QUERY} + +**Query Type:** {SQL_TYPE} + +**Query Results:** +{FORMATTED_RESULTS} + +**Instructions:** +1. Provide a clear, natural language answer to the user's question based on the query results +2. For SELECT queries: Focus on the key insights and findings from the data +3. For INSERT/UPDATE/DELETE queries: Confirm the operation was successful and mention how many records were affected +4. For other operations (CREATE, DROP, etc.): Confirm the operation was completed successfully +5. Use bullet points or numbered lists when presenting multiple items +6. Include relevant numbers, percentages, or trends if applicable +7. Be concise but comprehensive - avoid unnecessary technical jargon +8. If the results are empty, explain that no data was found matching the criteria +9. If there are many results, provide a summary with highlights +10. Do not mention the SQL query or technical database details unless specifically relevant to the user's understanding + +**Response Format:** +Provide a direct answer to the user's question in a conversational tone, as if you were explaining the findings to a colleague. +""" + + +class ResponseFormatterAgent: + """Agent for generating user-readable responses from SQL query results.""" + + def __init__(self): + """Initialize the response formatter agent.""" + pass + + def format_response(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str = "") -> str: + """ + Generate a user-readable response based on the SQL query results. + + Args: + user_query: The original user question + sql_query: The SQL query that was executed + query_results: The results from the SQL query execution + db_description: Description of the database context + + Returns: + A formatted, user-readable response string + """ + prompt = self._build_response_prompt(user_query, sql_query, query_results, db_description) + + messages = [{"role": "user", "content": prompt}] + + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=messages, + temperature=0.3, # Slightly higher temperature for more natural responses + top_p=1, + ) + + response = completion_result.choices[0].message.content + return response.strip() + + def _build_response_prompt(self, user_query: str, sql_query: str, query_results: List[Dict], db_description: str) -> str: + """Build the prompt for generating user-readable responses.""" + + # Format the query results for better readability + formatted_results = self._format_query_results(query_results) + + # Determine the type of SQL operation + sql_type = sql_query.strip().split()[0].upper() if sql_query else "UNKNOWN" + + prompt = RESPONSE_FORMATTER_PROMPT.format( + DB_DESCRIPTION=db_description if db_description else "Not provided", + USER_QUERY=user_query, + SQL_QUERY=sql_query, + SQL_TYPE=sql_type, + FORMATTED_RESULTS=formatted_results + ) + + return prompt + + def _format_query_results(self, query_results: List[Dict]) -> str: + """Format query results for inclusion in the prompt.""" + if not query_results: + return "No results found." + + if len(query_results) == 0: + return "No results found." + + # Check if this is an operation result (INSERT/UPDATE/DELETE) + if len(query_results) == 1 and "operation" in query_results[0]: + result = query_results[0] + operation = result.get("operation", "UNKNOWN") + affected_rows = result.get("affected_rows") + status = result.get("status", "unknown") + + if affected_rows is not None: + return f"Operation: {operation}, Status: {status}, Affected rows: {affected_rows}" + else: + return f"Operation: {operation}, Status: {status}" + + # Handle regular SELECT query results + # Limit the number of results shown in the prompt to avoid token limits + max_results_to_show = 50 + results_to_show = query_results[:max_results_to_show] + + formatted = [] + for i, result in enumerate(results_to_show, 1): + if isinstance(result, dict): + result_str = ", ".join([f"{k}: {v}" for k, v in result.items()]) + formatted.append(f"{i}. {result_str}") + else: + formatted.append(f"{i}. {result}") + + result_text = "\n".join(formatted) + + if len(query_results) > max_results_to_show: + result_text += f"\n... and {len(query_results) - max_results_to_show} more results" + + return result_text diff --git a/api/agents/taxonomy_agent.py b/api/agents/taxonomy_agent.py new file mode 100644 index 00000000..f3088a39 --- /dev/null +++ b/api/agents/taxonomy_agent.py @@ -0,0 +1,59 @@ +"""Taxonomy agent for taxonomy classification of questions and SQL queries.""" + +from litellm import completion +from api.config import Config + + +TAXONOMY_PROMPT = """You are an advanced taxonomy generator. For a pair of question and SQL query \ +provde a single clarification question to the user. +* For any SQL query that contain WHERE clause, provide a clarification question to the user about the \ +generated value. +* Your question can contain more than one clarification related to WHERE clause. +* Please asked only about the clarifications that you need and not extand the answer. +* Please ask in a polite, humen, and concise manner. +* Do not meantion any tables or columns in your ouput!. +* If you dont need any clarification, please answer with "I don't need any clarification." +* The user didnt saw the SQL queryor the tables, so please understand this position and ask the \ +clarification in that way he have the relevent information to answer. +* When you ask the user to confirm a value, please provide the value in your answer. +* Mention only question about values and dont mention the SQL query or the tables in your answer. + +Please create the clarification question step by step. + +Question: +{QUESTION} + +SQL: +{SQL} + +For example: +question: "How many diabetic patients are there?" +SQL: "SELECT COUNT(*) FROM patients WHERE disease_code = 'E11'" +Your output: "The diabitic desease code is E11? If not, please provide the correct diabitic desease code. + +The question to the user:" +""" + + +class TaxonomyAgent: + """Agent for taxonomy classification of questions and SQL queries.""" + + def __init__(self): + """Initialize the taxonomy agent.""" + + def get_answer(self, question: str, sql: str) -> str: + """Get taxonomy classification for a question and SQL pair.""" + messages = [ + { + "content": TAXONOMY_PROMPT.format(QUESTION=question, SQL=sql), + "role": "user", + } + ] + completion_result = completion( + model=Config.COMPLETION_MODEL, + messages=messages, + temperature=0, + ) + + answer = completion_result.choices[0].message.content + return answer diff --git a/api/agents/utils.py b/api/agents/utils.py new file mode 100644 index 00000000..25fefd2a --- /dev/null +++ b/api/agents/utils.py @@ -0,0 +1,33 @@ +"""Utility functions for agents.""" + +import json +from typing import Any, Dict + + +def parse_response(response: str) -> Dict[str, Any]: + """ + Parse Claude's response to extract the analysis. + + Args: + response: Claude's response string + + Returns: + Parsed analysis results + """ + try: + # Extract JSON from the response + json_start = response.find("{") + json_end = response.rfind("}") + 1 + json_str = response[json_start:json_end] + + # Parse the JSON + analysis = json.loads(json_str) + return analysis + except (json.JSONDecodeError, ValueError) as e: + # Fallback if JSON parsing fails + return { + "is_sql_translatable": False, + "confidence": 0, + "explanation": f"Failed to parse response: {str(e)}", + "error": str(response), + } diff --git a/api/index.py b/api/index.py index c702f714..072a89c2 100644 --- a/api/index.py +++ b/api/index.py @@ -253,8 +253,7 @@ def generate(): if sql_type in ['INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE']: # This is a destructive operation - ask for user confirmation - confirmation_message = f""" -⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ + confirmation_message = f"""⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ The generated SQL query will perform a **{sql_type}** operation: From 24f882a79cd47e746814afe86b81ab9f81e5443a Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 12:20:19 +0300 Subject: [PATCH 28/58] disable buttons on warning message after selection --- api/static/css/chat.css | 17 +++++++++++++++++ api/static/js/chat.js | 21 +++++++++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index e9d243ff..97cc8494 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -914,4 +914,21 @@ body { .confirm-btn:active, .cancel-btn:active { transform: translateY(0); box-shadow: 0 2px 6px rgba(0, 0, 0, 0.3); +} + +.confirm-btn:disabled, .cancel-btn:disabled { + background: #cccccc; + color: #888888; + border-color: #cccccc; + cursor: not-allowed; + transform: none; + box-shadow: none; +} + +.confirm-btn:disabled:hover, .cancel-btn:disabled:hover { + background: #cccccc; + color: #888888; + border-color: #cccccc; + transform: none; + box-shadow: none; } \ No newline at end of file diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 856c0010..144a5ea9 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -407,15 +407,18 @@ function addDestructiveConfirmationMessage(step) { messageDivContainer.className = "message-container bot-message-container destructive-confirmation-container"; messageDiv.className = "message bot-message destructive-confirmation-message"; + // Generate a unique ID for this confirmation dialog + const confirmationId = 'confirmation-' + Date.now(); + // Create the confirmation UI const confirmationHTML = ` -
+
${step.message.replace(/\n/g, '
')}
- -
@@ -433,7 +436,17 @@ function addDestructiveConfirmationMessage(step) { submitButton.disabled = true; } -async function handleDestructiveConfirmation(confirmation, sqlQuery) { +async function handleDestructiveConfirmation(confirmation, sqlQuery, confirmationId) { + // Find the specific confirmation dialog using the unique ID + const confirmationDialog = document.querySelector(`[data-confirmation-id="${confirmationId}"]`); + if (confirmationDialog) { + // Disable both confirmation buttons within this specific dialog + const confirmBtn = confirmationDialog.querySelector('.confirm-btn'); + const cancelBtn = confirmationDialog.querySelector('.cancel-btn'); + if (confirmBtn) confirmBtn.disabled = true; + if (cancelBtn) cancelBtn.disabled = true; + } + // Re-enable the input messageInput.disabled = false; submitButton.disabled = false; From 2843da9bf4de6ff687d5f4ee4f3e3c6b5aec8f49 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 13:59:35 +0300 Subject: [PATCH 29/58] refresh schema on db change --- api/index.py | 103 ++++++++++++++++++++- api/loaders/postgres_loader.py | 158 +++++++++++++++++++++++++++------ 2 files changed, 230 insertions(+), 31 deletions(-) diff --git a/api/index.py b/api/index.py index 072a89c2..018709bb 100644 --- a/api/index.py +++ b/api/index.py @@ -296,6 +296,9 @@ def generate(): step = {"type": "reasoning_step", "message": "Step 2: Executing SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER + # Check if this query modifies the database schema + is_schema_modifying, operation_type = PostgresLoader.is_schema_modifying_query(sql_query) + query_results = PostgresLoader.execute_sql_query(answer_an["sql_query"], db_url) yield json.dumps( { @@ -304,8 +307,33 @@ def generate(): } ) + MESSAGE_DELIMITER + # If schema was modified, refresh the graph + if is_schema_modifying: + step = {"type": "reasoning_step", "message": "Step 3: Schema change detected - refreshing graph..."} + yield json.dumps(step) + MESSAGE_DELIMITER + + refresh_success, refresh_message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if refresh_success: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"✅ Schema change detected ({operation_type} operation)\n\n🔄 Graph schema has been automatically refreshed with the latest database structure.", + "refresh_status": "success" + } + ) + MESSAGE_DELIMITER + else: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"⚠️ Schema was modified but graph refresh failed: {refresh_message}", + "refresh_status": "failed" + } + ) + MESSAGE_DELIMITER + # Generate user-readable response using AI - step = {"type": "reasoning_step", "message": "Step 3: Generating user-friendly response"} + step_num = "4" if is_schema_modifying else "3" + step = {"type": "reasoning_step", "message": f"Step {step_num}: Generating user-friendly response"} yield json.dumps(step) + MESSAGE_DELIMITER response_agent = ResponseFormatterAgent() @@ -353,9 +381,12 @@ def generate_confirmation(): try: db_description, db_url = get_db_description(graph_id) - step = {"type": "reasoning_step", "message": "Step 1: Executing confirmed SQL query"} + step = {"type": "reasoning_step", "message": "Step 2: Executing confirmed SQL query"} yield json.dumps(step) + MESSAGE_DELIMITER + # Check if this query modifies the database schema + is_schema_modifying, operation_type = PostgresLoader.is_schema_modifying_query(sql_query) + query_results = PostgresLoader.execute_sql_query(sql_query, db_url) yield json.dumps( { @@ -364,8 +395,33 @@ def generate_confirmation(): } ) + MESSAGE_DELIMITER + # If schema was modified, refresh the graph + if is_schema_modifying: + step = {"type": "reasoning_step", "message": "Step 3: Schema change detected - refreshing graph..."} + yield json.dumps(step) + MESSAGE_DELIMITER + + refresh_success, refresh_message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if refresh_success: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"✅ Schema change detected ({operation_type} operation)\n\n🔄 Graph schema has been automatically refreshed with the latest database structure.", + "refresh_status": "success" + } + ) + MESSAGE_DELIMITER + else: + yield json.dumps( + { + "type": "schema_refresh", + "message": f"⚠️ Schema was modified but graph refresh failed: {refresh_message}", + "refresh_status": "failed" + } + ) + MESSAGE_DELIMITER + # Generate user-readable response using AI - step = {"type": "reasoning_step", "message": "Step 2: Generating user-friendly response"} + step_num = "4" if is_schema_modifying else "3" + step = {"type": "reasoning_step", "message": f"Step {step_num}: Generating user-friendly response"} yield json.dumps(step) + MESSAGE_DELIMITER response_agent = ResponseFormatterAgent() @@ -447,6 +503,47 @@ def logout(): session.clear() return redirect(url_for("home")) +@app.route("/graphs//refresh", methods=["POST"]) +@token_required # Apply token authentication decorator +def refresh_graph_schema(graph_id: str): + """ + Manually refresh the graph schema from the database. + This endpoint allows users to manually trigger a schema refresh + if they suspect the graph is out of sync with the database. + """ + graph_id = g.user_id + "_" + graph_id.strip() + + try: + # Get database connection details + db_description, db_url = get_db_description(graph_id) + + if not db_url or db_url == "No URL available for this database.": + return jsonify({ + "success": False, + "error": "No database URL found for this graph" + }), 400 + + # Perform schema refresh + success, message = PostgresLoader.refresh_graph_schema(graph_id, db_url) + + if success: + return jsonify({ + "success": True, + "message": f"Graph schema refreshed successfully. {message}" + }), 200 + else: + return jsonify({ + "success": False, + "error": f"Failed to refresh schema: {message}" + }), 500 + + except Exception as e: + logging.error("Error in manual schema refresh: %s", e) + return jsonify({ + "success": False, + "error": f"Error refreshing schema: {str(e)}" + }), 500 + @app.route("/database", methods=["POST"]) @token_required # Apply token authentication decorator def connect_database(): diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 915e0ca0..5c0cee99 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -1,15 +1,40 @@ from typing import Tuple, Dict, Any, List -import psycopg2 +import re +import logging import tqdm +import psycopg2 from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") + class PostgresLoader(BaseLoader): """ Loader for PostgreSQL databases that connects and extracts schema information. """ + # DDL operations that modify database schema + SCHEMA_MODIFYING_OPERATIONS = { + 'CREATE', 'ALTER', 'DROP', 'RENAME', 'TRUNCATE' + } + + # More specific patterns for schema-affecting operations + SCHEMA_PATTERNS = [ + r'^\s*CREATE\s+TABLE', + r'^\s*CREATE\s+INDEX', + r'^\s*CREATE\s+UNIQUE\s+INDEX', + r'^\s*ALTER\s+TABLE', + r'^\s*DROP\s+TABLE', + r'^\s*DROP\s+INDEX', + r'^\s*RENAME\s+TABLE', + r'^\s*TRUNCATE\s+TABLE', + r'^\s*CREATE\s+VIEW', + r'^\s*DROP\s+VIEW', + r'^\s*CREATE\s+SCHEMA', + r'^\s*DROP\s+SCHEMA', + ] + @staticmethod def load(prefix: str, connection_url: str) -> Tuple[bool, str]: """ @@ -64,7 +89,7 @@ def extract_tables_info(cursor) -> Dict[str, Any]: Dict containing table information """ entities = {} - + # Get all tables in public schema cursor.execute(""" SELECT table_name, table_comment @@ -80,31 +105,31 @@ def extract_tables_info(cursor) -> Dict[str, Any]: AND t.table_type = 'BASE TABLE' ORDER BY t.table_name; """) - + tables = cursor.fetchall() - + for table_name, table_comment in tqdm.tqdm(tables, desc="Extracting table information"): table_name = table_name.strip() - + # Get column information for this table columns_info = PostgresLoader.extract_columns_info(cursor, table_name) - + # Get foreign keys for this table foreign_keys = PostgresLoader.extract_foreign_keys(cursor, table_name) - + # Generate table description table_description = table_comment if table_comment else f"Table: {table_name}" - + # Get column descriptions for batch embedding col_descriptions = [col_info['description'] for col_info in columns_info.values()] - + entities[table_name] = { 'description': table_description, 'columns': columns_info, 'foreign_keys': foreign_keys, 'col_descriptions': col_descriptions } - + return entities @staticmethod @@ -155,29 +180,29 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: AND c.table_schema = 'public' ORDER BY c.ordinal_position; """, (table_name, table_name, table_name)) - + columns = cursor.fetchall() columns_info = {} - + for col_name, data_type, is_nullable, column_default, key_type, column_comment in columns: col_name = col_name.strip() - + # Generate column description description_parts = [] if column_comment: description_parts.append(column_comment) else: description_parts.append(f"Column {col_name} of type {data_type}") - + if key_type != 'NONE': description_parts.append(f"({key_type})") - + if is_nullable == 'NO': description_parts.append("(NOT NULL)") - + if column_default: description_parts.append(f"(Default: {column_default})") - + columns_info[col_name] = { 'type': data_type, 'null': is_nullable, @@ -185,7 +210,7 @@ def extract_columns_info(cursor, table_name: str) -> Dict[str, Any]: 'description': ' '.join(description_parts), 'default': column_default } - + return columns_info @staticmethod @@ -217,7 +242,7 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: AND tc.table_name = %s AND tc.table_schema = 'public'; """, (table_name,)) - + foreign_keys = [] for constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall(): foreign_keys.append({ @@ -226,7 +251,7 @@ def extract_foreign_keys(cursor, table_name: str) -> List[Dict[str, str]]: 'referenced_table': foreign_table.strip(), 'referenced_column': foreign_column.strip() }) - + return foreign_keys @staticmethod @@ -258,15 +283,15 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: AND tc.table_schema = 'public' ORDER BY tc.table_name, tc.constraint_name; """) - + relationships = {} for table_name, constraint_name, column_name, foreign_table, foreign_column in cursor.fetchall(): table_name = table_name.strip() constraint_name = constraint_name.strip() - + if constraint_name not in relationships: relationships[constraint_name] = [] - + relationships[constraint_name].append({ 'from': table_name, 'to': foreign_table.strip(), @@ -274,9 +299,86 @@ def extract_relationships(cursor) -> Dict[str, List[Dict[str, str]]]: 'target_column': foreign_column.strip(), 'note': f'Foreign key constraint: {constraint_name}' }) - + return relationships - + + @staticmethod + def is_schema_modifying_query(sql_query: str) -> Tuple[bool, str]: + """ + Check if a SQL query modifies the database schema. + + Args: + sql_query: The SQL query to check + + Returns: + Tuple of (is_schema_modifying, operation_type) + """ + if not sql_query or not sql_query.strip(): + return False, "" + + # Clean and normalize the query + normalized_query = sql_query.strip().upper() + + # Check for basic DDL operations + first_word = normalized_query.split()[0] if normalized_query.split() else "" + if first_word in PostgresLoader.SCHEMA_MODIFYING_OPERATIONS: + # Additional pattern matching for more precise detection + for pattern in PostgresLoader.SCHEMA_PATTERNS: + if re.match(pattern, normalized_query, re.IGNORECASE): + return True, first_word + + # If it's a known DDL operation but doesn't match specific patterns, + # still consider it schema-modifying (better safe than sorry) + return True, first_word + + return False, "" + + @staticmethod + def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: + """ + Refresh the graph schema by clearing existing data and reloading from the database. + + Args: + graph_id: The graph ID to refresh + db_url: Database connection URL + + Returns: + Tuple of (success, message) + """ + try: + logging.info("Schema modification detected. Refreshing graph schema for: %s", graph_id) + + # Import here to avoid circular imports + from api.extensions import db + + # Clear existing graph data + # Drop current graph before reloading + graph = db.select_graph(graph_id) + graph.delete() + + # Extract prefix from graph_id (remove database name part) + # graph_id format is typically "prefix_database_name" + parts = graph_id.split('_') + if len(parts) >= 2: + # Reconstruct prefix by joining all parts except the last one + prefix = '_'.join(parts[:-1]) + else: + prefix = graph_id + + # Reuse the existing load method to reload the schema + success, message = PostgresLoader.load(prefix, db_url) + + if success: + logging.info("Graph schema refreshed successfully.") + return True, message + else: + return False, f"Failed to reload schema: {message}" + + except Exception as e: + error_msg = f"Error refreshing graph schema: {str(e)}" + logging.error(error_msg) + return False, error_msg + @staticmethod def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: """ @@ -297,7 +399,7 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: # Execute the SQL query cursor.execute(sql_query) - + # Check if the query returns results (SELECT queries) if cursor.description is not None: # This is a SELECT query or similar that returns rows @@ -309,7 +411,7 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: # Return information about the operation affected_rows = cursor.rowcount sql_type = sql_query.strip().split()[0].upper() - + if sql_type in ['INSERT', 'UPDATE', 'DELETE']: result_list = [{ "operation": sql_type, @@ -322,7 +424,7 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: "operation": sql_type, "status": "success" }] - + # Commit the transaction for write operations conn.commit() From 01623cacffedee92e7033b7e2050dd5012d7894b Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 14:39:42 +0300 Subject: [PATCH 30/58] change user bubble --- api/static/css/chat.css | 1 + 1 file changed, 1 insertion(+) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 97cc8494..2a505e8d 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -152,6 +152,7 @@ body { .user-message-container::before { margin-right: 10px; background: var(--falkor-quaternary); + content: 'User'; } .bot-message-container::after { From eafdefc5ca0091e5bbcae0e4fe1f44ad31650a6c Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 15:05:35 +0300 Subject: [PATCH 31/58] wrap a long text --- api/static/css/chat.css | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 2a505e8d..e73f76e4 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -186,6 +186,18 @@ body { margin: 5px 0; line-height: 1.4; color: var(--text-primary); + word-wrap: break-word; + overflow-wrap: break-word; + overflow-x: auto; + white-space: pre-wrap; +} + +/* Styles for formatted text blocks */ +.sql-line, .array-line, .plain-line { + word-wrap: break-word; + overflow-wrap: break-word; + white-space: pre-wrap; + margin: 2px 0; } .bot-message { From 4e52f19ef2913cf445853c9e278655c971fc733b Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 15:28:11 +0300 Subject: [PATCH 32/58] json fix --- api/loaders/postgres_loader.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 5c0cee99..a12e3ac2 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -3,6 +3,8 @@ import logging import tqdm import psycopg2 +import datetime +import decimal from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph @@ -35,6 +37,28 @@ class PostgresLoader(BaseLoader): r'^\s*DROP\s+SCHEMA', ] + @staticmethod + def _serialize_value(value): + """ + Convert non-JSON serializable values to JSON serializable format. + + Args: + value: The value to serialize + + Returns: + JSON serializable version of the value + """ + if isinstance(value, (datetime.date, datetime.datetime)): + return value.isoformat() + elif isinstance(value, datetime.time): + return value.isoformat() + elif isinstance(value, decimal.Decimal): + return float(value) + elif value is None: + return None + else: + return value + @staticmethod def load(prefix: str, connection_url: str) -> Tuple[bool, str]: """ @@ -405,7 +429,14 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: # This is a SELECT query or similar that returns rows columns = [desc[0] for desc in cursor.description] results = cursor.fetchall() - result_list = [dict(zip(columns, row)) for row in results] + result_list = [] + for row in results: + # Serialize each value to ensure JSON compatibility + serialized_row = { + columns[i]: PostgresLoader._serialize_value(row[i]) + for i in range(len(columns)) + } + result_list.append(serialized_row) else: # This is an INSERT, UPDATE, DELETE, or other non-SELECT query # Return information about the operation From 1753230479abc4d8490c29ff9fd52d3bd560a3ca Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 15:56:10 +0300 Subject: [PATCH 33/58] add spinner on graph load --- api/static/css/chat.css | 32 ++++++++++++++++++++++++++++++++ api/static/js/chat.js | 30 +++++++++++++++++++++++++++++- api/templates/chat.j2 | 8 +++++++- 3 files changed, 68 insertions(+), 2 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index e73f76e4..f617ae11 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -680,6 +680,28 @@ body { .pg-modal-connect:hover { background: #3367d6; } +.pg-modal-connect:disabled { + background: #6c8db8; + cursor: not-allowed; +} +.pg-modal-loading-spinner { + display: flex; + align-items: center; + justify-content: center; + gap: 8px; +} +.spinner { + width: 16px; + height: 16px; + border: 2px solid #ffffff40; + border-top: 2px solid #ffffff; + border-radius: 50%; + animation: spin 1s linear infinite; +} +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} .pg-modal-cancel { background: #e0e0e0; color: #333; @@ -687,6 +709,16 @@ body { .pg-modal-cancel:hover { background: #cacaca; } +.pg-modal-cancel:disabled { + background: #f0f0f0; + color: #999; + cursor: not-allowed; +} +.pg-modal-input:disabled { + background: #f8f8f8; + color: #999; + cursor: not-allowed; +} .google-login-modal { display: none; position: fixed; diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 144a5ea9..5de0dd26 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -760,6 +760,18 @@ document.addEventListener('DOMContentLoaded', function() { alert('Please enter a Postgres URL.'); return; } + + // Show loading state + const connectText = connectPgModalBtn.querySelector('.pg-modal-connect-text'); + const loadingSpinner = connectPgModalBtn.querySelector('.pg-modal-loading-spinner'); + const cancelBtn = document.getElementById('pg-modal-cancel'); + + connectText.style.display = 'none'; + loadingSpinner.style.display = 'flex'; + connectPgModalBtn.disabled = true; + cancelBtn.disabled = true; + pgUrlInput.disabled = true; + fetch('/database', { method: 'POST', headers: { @@ -769,13 +781,29 @@ document.addEventListener('DOMContentLoaded', function() { }) .then(response => response.json()) .then(data => { + // Reset loading state + connectText.style.display = 'inline'; + loadingSpinner.style.display = 'none'; + connectPgModalBtn.disabled = false; + cancelBtn.disabled = false; + pgUrlInput.disabled = false; + if (data.success) { - pgModal.style.display = 'none'; // Close modal on success, no alert + pgModal.style.display = 'none'; // Close modal on success + // Refresh the graph list to show the new database + location.reload(); } else { alert('Failed to connect: ' + (data.error || 'Unknown error')); } }) .catch(error => { + // Reset loading state on error + connectText.style.display = 'inline'; + loadingSpinner.style.display = 'none'; + connectPgModalBtn.disabled = false; + cancelBtn.disabled = false; + pgUrlInput.disabled = false; + alert('Error connecting to database: ' + error.message); }); }); diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 8d03aa42..98706011 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -119,7 +119,13 @@

Connect to Postgres

- +
From c92c93ff40680cf297201e70fae89a66b1861301 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 16:04:59 +0300 Subject: [PATCH 34/58] change text on spinning --- api/templates/chat.j2 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 98706011..3990690f 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -123,7 +123,7 @@ Connect From 99c73d081aa9b9b12939182ed0bbe6d5c619163b Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 16:05:25 +0300 Subject: [PATCH 35/58] add example SQL file for CRM --- examples/crm.sql | 611 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 examples/crm.sql diff --git a/examples/crm.sql b/examples/crm.sql new file mode 100644 index 00000000..1c05fd6c --- /dev/null +++ b/examples/crm.sql @@ -0,0 +1,611 @@ +-- SQL Script 1 (Extended): Table Creation (DDL) with Comments +-- This script creates the tables for your CRM database and adds descriptions for each table and column. + +-- Drop existing tables to start fresh +DROP TABLE IF EXISTS SalesOrderItems, SalesOrders, Invoices, Payments, Products, ProductCategories, Leads, Opportunities, Contacts, Customers, Campaigns, CampaignMembers, Tasks, Notes, Attachments, SupportTickets, TicketComments, Users, Roles, UserRoles CASCADE; + +-- Roles for access control +CREATE TABLE Roles ( + RoleID SERIAL PRIMARY KEY, + RoleName VARCHAR(50) UNIQUE NOT NULL +); +COMMENT ON TABLE Roles IS 'Defines user roles for access control within the CRM (e.g., Admin, Sales Manager).'; +COMMENT ON COLUMN Roles.RoleID IS 'Unique identifier for the role.'; +COMMENT ON COLUMN Roles.RoleName IS 'Name of the role (e.g., "Admin", "Sales Representative").'; + +-- Users of the CRM system +CREATE TABLE Users ( + UserID SERIAL PRIMARY KEY, + Username VARCHAR(50) UNIQUE NOT NULL, + PasswordHash VARCHAR(255) NOT NULL, + Email VARCHAR(100) UNIQUE NOT NULL, + FirstName VARCHAR(50), + LastName VARCHAR(50), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Users IS 'Stores information about users who can log in to the CRM system.'; +COMMENT ON COLUMN Users.UserID IS 'Unique identifier for the user.'; +COMMENT ON COLUMN Users.Username IS 'The username for logging in.'; +COMMENT ON COLUMN Users.PasswordHash IS 'Hashed password for security.'; +COMMENT ON COLUMN Users.Email IS 'The user''s email address.'; +COMMENT ON COLUMN Users.FirstName IS 'The user''s first name.'; +COMMENT ON COLUMN Users.LastName IS 'The user''s last name.'; +COMMENT ON COLUMN Users.CreatedAt IS 'Timestamp when the user account was created.'; + +-- Junction table for Users and Roles +CREATE TABLE UserRoles ( + UserID INT REFERENCES Users(UserID), + RoleID INT REFERENCES Roles(RoleID), + PRIMARY KEY (UserID, RoleID) +); +COMMENT ON TABLE UserRoles IS 'Maps users to their assigned roles, supporting many-to-many relationships.'; +COMMENT ON COLUMN UserRoles.UserID IS 'Foreign key referencing the Users table.'; +COMMENT ON COLUMN UserRoles.RoleID IS 'Foreign key referencing the Roles table.'; + +-- Customer accounts +CREATE TABLE Customers ( + CustomerID SERIAL PRIMARY KEY, + CustomerName VARCHAR(100) NOT NULL, + Industry VARCHAR(50), + Website VARCHAR(100), + Phone VARCHAR(20), + Address VARCHAR(255), + City VARCHAR(50), + State VARCHAR(50), + ZipCode VARCHAR(20), + Country VARCHAR(50), + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Customers IS 'Represents customer accounts or companies.'; +COMMENT ON COLUMN Customers.CustomerID IS 'Unique identifier for the customer.'; +COMMENT ON COLUMN Customers.CustomerName IS 'The name of the customer company.'; +COMMENT ON COLUMN Customers.Industry IS 'The industry the customer belongs to.'; +COMMENT ON COLUMN Customers.Website IS 'The customer''s official website.'; +COMMENT ON COLUMN Customers.Phone IS 'The customer''s primary phone number.'; +COMMENT ON COLUMN Customers.Address IS 'The customer''s physical address.'; +COMMENT ON COLUMN Customers.City IS 'The city part of the address.'; +COMMENT ON COLUMN Customers.State IS 'The state or province part of the address.'; +COMMENT ON COLUMN Customers.ZipCode IS 'The postal or zip code.'; +COMMENT ON COLUMN Customers.Country IS 'The country part of the address.'; +COMMENT ON COLUMN Customers.AssignedTo IS 'The user (sales representative) assigned to this customer account.'; +COMMENT ON COLUMN Customers.CreatedAt IS 'Timestamp when the customer was added.'; + +-- Individual contacts associated with customers +CREATE TABLE Contacts ( + ContactID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + FirstName VARCHAR(50) NOT NULL, + LastName VARCHAR(50) NOT NULL, + Email VARCHAR(100) UNIQUE, + Phone VARCHAR(20), + JobTitle VARCHAR(50), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Contacts IS 'Stores information about individual contacts associated with customer accounts.'; +COMMENT ON COLUMN Contacts.ContactID IS 'Unique identifier for the contact.'; +COMMENT ON COLUMN Contacts.CustomerID IS 'Foreign key linking the contact to a customer account.'; +COMMENT ON COLUMN Contacts.FirstName IS 'The contact''s first name.'; +COMMENT ON COLUMN Contacts.LastName IS 'The contact''s last name.'; +COMMENT ON COLUMN Contacts.Email IS 'The contact''s email address.'; +COMMENT ON COLUMN Contacts.Phone IS 'The contact''s phone number.'; +COMMENT ON COLUMN Contacts.JobTitle IS 'The contact''s job title or position.'; +COMMENT ON COLUMN Contacts.CreatedAt IS 'Timestamp when the contact was created.'; + +-- Potential sales leads +CREATE TABLE Leads ( + LeadID SERIAL PRIMARY KEY, + FirstName VARCHAR(50), + LastName VARCHAR(50), + Email VARCHAR(100), + Phone VARCHAR(20), + Company VARCHAR(100), + Status VARCHAR(50) DEFAULT 'New', + Source VARCHAR(50), + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Leads IS 'Represents potential customers or sales prospects (not yet qualified).'; +COMMENT ON COLUMN Leads.LeadID IS 'Unique identifier for the lead.'; +COMMENT ON COLUMN Leads.Status IS 'Current status of the lead (e.g., New, Contacted, Qualified, Lost).'; +COMMENT ON COLUMN Leads.Source IS 'The source from which the lead was generated (e.g., Website, Referral).'; +COMMENT ON COLUMN Leads.AssignedTo IS 'The user assigned to follow up with this lead.'; +COMMENT ON COLUMN Leads.CreatedAt IS 'Timestamp when the lead was created.'; + +-- Sales opportunities +CREATE TABLE Opportunities ( + OpportunityID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + OpportunityName VARCHAR(100) NOT NULL, + Stage VARCHAR(50) DEFAULT 'Prospecting', + Amount DECIMAL(12, 2), + CloseDate DATE, + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Opportunities IS 'Tracks qualified sales deals with potential revenue.'; +COMMENT ON COLUMN Opportunities.OpportunityID IS 'Unique identifier for the opportunity.'; +COMMENT ON COLUMN Opportunities.CustomerID IS 'Foreign key linking the opportunity to a customer account.'; +COMMENT ON COLUMN Opportunities.OpportunityName IS 'A descriptive name for the sales opportunity.'; +COMMENT ON COLUMN Opportunities.Stage IS 'The current stage in the sales pipeline (e.g., Prospecting, Proposal, Closed Won).'; +COMMENT ON COLUMN Opportunities.Amount IS 'The estimated value of the opportunity.'; +COMMENT ON COLUMN Opportunities.CloseDate IS 'The expected date the deal will close.'; +COMMENT ON COLUMN Opportunities.AssignedTo IS 'The user responsible for this opportunity.'; +COMMENT ON COLUMN Opportunities.CreatedAt IS 'Timestamp when the opportunity was created.'; + +-- Product categories +CREATE TABLE ProductCategories ( + CategoryID SERIAL PRIMARY KEY, + CategoryName VARCHAR(50) NOT NULL, + Description TEXT +); +COMMENT ON TABLE ProductCategories IS 'Used to group products into categories (e.g., Software, Hardware).'; +COMMENT ON COLUMN ProductCategories.CategoryID IS 'Unique identifier for the category.'; +COMMENT ON COLUMN ProductCategories.CategoryName IS 'Name of the product category.'; +COMMENT ON COLUMN ProductCategories.Description IS 'A brief description of the category.'; + +-- Products or services offered +CREATE TABLE Products ( + ProductID SERIAL PRIMARY KEY, + ProductName VARCHAR(100) NOT NULL, + CategoryID INT REFERENCES ProductCategories(CategoryID), + Description TEXT, + Price DECIMAL(10, 2) NOT NULL, + StockQuantity INT DEFAULT 0 +); +COMMENT ON TABLE Products IS 'Stores details of the products or services the company sells.'; +COMMENT ON COLUMN Products.ProductID IS 'Unique identifier for the product.'; +COMMENT ON COLUMN Products.ProductName IS 'Name of the product.'; +COMMENT ON COLUMN Products.CategoryID IS 'Foreign key linking the product to a category.'; +COMMENT ON COLUMN Products.Description IS 'Detailed description of the product.'; +COMMENT ON COLUMN Products.Price IS 'The unit price of the product.'; +COMMENT ON COLUMN Products.StockQuantity IS 'The quantity of the product available in stock.'; + +-- Sales orders +CREATE TABLE SalesOrders ( + OrderID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + OpportunityID INT REFERENCES Opportunities(OpportunityID), + OrderDate DATE NOT NULL, + Status VARCHAR(50) DEFAULT 'Pending', + TotalAmount DECIMAL(12, 2), + AssignedTo INT REFERENCES Users(UserID) +); +COMMENT ON TABLE SalesOrders IS 'Records of confirmed sales to customers.'; +COMMENT ON COLUMN SalesOrders.OrderID IS 'Unique identifier for the sales order.'; +COMMENT ON COLUMN SalesOrders.CustomerID IS 'Foreign key linking the order to a customer.'; +COMMENT ON COLUMN SalesOrders.OpportunityID IS 'Foreign key linking the order to the sales opportunity it came from.'; +COMMENT ON COLUMN SalesOrders.OrderDate IS 'The date the order was placed.'; +COMMENT ON COLUMN SalesOrders.Status IS 'The current status of the order (e.g., Pending, Shipped, Canceled).'; +COMMENT ON COLUMN SalesOrders.TotalAmount IS 'The total calculated amount for the order.'; +COMMENT ON COLUMN SalesOrders.AssignedTo IS 'The user who processed the order.'; + +-- Items within a sales order +CREATE TABLE SalesOrderItems ( + OrderItemID SERIAL PRIMARY KEY, + OrderID INT REFERENCES SalesOrders(OrderID) ON DELETE CASCADE, + ProductID INT REFERENCES Products(ProductID), + Quantity INT NOT NULL, + UnitPrice DECIMAL(10, 2) NOT NULL +); +COMMENT ON TABLE SalesOrderItems IS 'Line items for each product within a sales order.'; +COMMENT ON COLUMN SalesOrderItems.OrderItemID IS 'Unique identifier for the order item.'; +COMMENT ON COLUMN SalesOrderItems.OrderID IS 'Foreign key linking this item to a sales order.'; +COMMENT ON COLUMN SalesOrderItems.ProductID IS 'Foreign key linking to the product being ordered.'; +COMMENT ON COLUMN SalesOrderItems.Quantity IS 'The quantity of the product ordered.'; +COMMENT ON COLUMN SalesOrderItems.UnitPrice IS 'The price per unit at the time of sale.'; + +-- Invoices for sales +CREATE TABLE Invoices ( + InvoiceID SERIAL PRIMARY KEY, + OrderID INT REFERENCES SalesOrders(OrderID), + InvoiceDate DATE NOT NULL, + DueDate DATE, + TotalAmount DECIMAL(12, 2), + Status VARCHAR(50) DEFAULT 'Unpaid' +); +COMMENT ON TABLE Invoices IS 'Represents billing invoices sent to customers.'; +COMMENT ON COLUMN Invoices.InvoiceID IS 'Unique identifier for the invoice.'; +COMMENT ON COLUMN Invoices.OrderID IS 'Foreign key linking the invoice to a sales order.'; +COMMENT ON COLUMN Invoices.InvoiceDate IS 'The date the invoice was issued.'; +COMMENT ON COLUMN Invoices.DueDate IS 'The date the payment is due.'; +COMMENT ON COLUMN Invoices.TotalAmount IS 'The total amount due on the invoice.'; +COMMENT ON COLUMN Invoices.Status IS 'The payment status of the invoice (e.g., Unpaid, Paid, Overdue).'; + +-- Payment records +CREATE TABLE Payments ( + PaymentID SERIAL PRIMARY KEY, + InvoiceID INT REFERENCES Invoices(InvoiceID), + PaymentDate DATE NOT NULL, + Amount DECIMAL(12, 2), + PaymentMethod VARCHAR(50) +); +COMMENT ON TABLE Payments IS 'Tracks payments received from customers against invoices.'; +COMMENT ON COLUMN Payments.PaymentID IS 'Unique identifier for the payment.'; +COMMENT ON COLUMN Payments.InvoiceID IS 'Foreign key linking the payment to an invoice.'; +COMMENT ON COLUMN Payments.PaymentDate IS 'The date the payment was received.'; +COMMENT ON COLUMN Payments.Amount IS 'The amount that was paid.'; +COMMENT ON COLUMN Payments.PaymentMethod IS 'The method of payment (e.g., Credit Card, Bank Transfer).'; + +-- Marketing campaigns +CREATE TABLE Campaigns ( + CampaignID SERIAL PRIMARY KEY, + CampaignName VARCHAR(100) NOT NULL, + StartDate DATE, + EndDate DATE, + Budget DECIMAL(12, 2), + Status VARCHAR(50), + Owner INT REFERENCES Users(UserID) +); +COMMENT ON TABLE Campaigns IS 'Stores information about marketing campaigns.'; +COMMENT ON COLUMN Campaigns.CampaignID IS 'Unique identifier for the campaign.'; +COMMENT ON COLUMN Campaigns.CampaignName IS 'The name of the marketing campaign.'; +COMMENT ON COLUMN Campaigns.StartDate IS 'The start date of the campaign.'; +COMMENT ON COLUMN Campaigns.EndDate IS 'The end date of the campaign.'; +COMMENT ON COLUMN Campaigns.Budget IS 'The allocated budget for the campaign.'; +COMMENT ON COLUMN Campaigns.Status IS 'The current status of the campaign (e.g., Planned, Active, Completed).'; +COMMENT ON COLUMN Campaigns.Owner IS 'The user responsible for the campaign.'; + +-- Members of a marketing campaign (leads or contacts) +CREATE TABLE CampaignMembers ( + CampaignMemberID SERIAL PRIMARY KEY, + CampaignID INT REFERENCES Campaigns(CampaignID), + LeadID INT REFERENCES Leads(LeadID), + ContactID INT REFERENCES Contacts(ContactID), + Status VARCHAR(50) +); +COMMENT ON TABLE CampaignMembers IS 'Links leads and contacts to the marketing campaigns they are a part of.'; +COMMENT ON COLUMN CampaignMembers.CampaignMemberID IS 'Unique identifier for the campaign member record.'; +COMMENT ON COLUMN CampaignMembers.CampaignID IS 'Foreign key linking to the campaign.'; +COMMENT ON COLUMN CampaignMembers.LeadID IS 'Foreign key linking to a lead (if the member is a lead).'; +COMMENT ON COLUMN CampaignMembers.ContactID IS 'Foreign key linking to a contact (if the member is a contact).'; +COMMENT ON COLUMN CampaignMembers.Status IS 'The status of the member in the campaign (e.g., Sent, Responded).'; + +-- Tasks for users +CREATE TABLE Tasks ( + TaskID SERIAL PRIMARY KEY, + Title VARCHAR(100) NOT NULL, + Description TEXT, + DueDate DATE, + Status VARCHAR(50) DEFAULT 'Not Started', + Priority VARCHAR(20) DEFAULT 'Normal', + AssignedTo INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT +); +COMMENT ON TABLE Tasks IS 'Tracks tasks or to-do items for CRM users.'; +COMMENT ON COLUMN Tasks.TaskID IS 'Unique identifier for the task.'; +COMMENT ON COLUMN Tasks.Title IS 'A short title for the task.'; +COMMENT ON COLUMN Tasks.Description IS 'A detailed description of the task.'; +COMMENT ON COLUMN Tasks.DueDate IS 'The date the task is due to be completed.'; +COMMENT ON COLUMN Tasks.Status IS 'The current status of the task (e.g., Not Started, In Progress, Completed).'; +COMMENT ON COLUMN Tasks.Priority IS 'The priority level of the task (e.g., Low, Normal, High).'; +COMMENT ON COLUMN Tasks.AssignedTo IS 'The user the task is assigned to.'; +COMMENT ON COLUMN Tasks.RelatedToEntity IS 'The type of record this task is related to (e.g., ''Lead'', ''Opportunity'').'; +COMMENT ON COLUMN Tasks.RelatedToID IS 'The ID of the related record.'; + +-- Notes related to various records +CREATE TABLE Notes ( + NoteID SERIAL PRIMARY KEY, + Content TEXT NOT NULL, + CreatedBy INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT, + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Notes IS 'Allows users to add notes to various records (e.g., contacts, opportunities).'; +COMMENT ON COLUMN Notes.NoteID IS 'Unique identifier for the note.'; +COMMENT on COLUMN Notes.Content IS 'The text content of the note.'; +COMMENT ON COLUMN Notes.CreatedBy IS 'The user who created the note.'; +COMMENT ON COLUMN Notes.RelatedToEntity IS 'The type of record this note is related to (e.g., ''Contact'', ''Customer'').'; +COMMENT ON COLUMN Notes.RelatedToID IS 'The ID of the related record.'; +COMMENT ON COLUMN Notes.CreatedAt IS 'Timestamp when the note was created.'; + +-- File attachments +CREATE TABLE Attachments ( + AttachmentID SERIAL PRIMARY KEY, + FileName VARCHAR(255) NOT NULL, + FilePath VARCHAR(255) NOT NULL, + FileSize INT, + FileType VARCHAR(100), + UploadedBy INT REFERENCES Users(UserID), + RelatedToEntity VARCHAR(50), + RelatedToID INT, + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE Attachments IS 'Stores metadata about files attached to records in the CRM.'; +COMMENT ON COLUMN Attachments.AttachmentID IS 'Unique identifier for the attachment.'; +COMMENT ON COLUMN Attachments.FileName IS 'The original name of the uploaded file.'; +COMMENT ON COLUMN Attachments.FilePath IS 'The path where the file is stored on the server.'; +COMMENT ON COLUMN Attachments.FileSize IS 'The size of the file in bytes.'; +COMMENT ON COLUMN Attachments.FileType IS 'The MIME type of the file (e.g., ''application/pdf'').'; +COMMENT ON COLUMN Attachments.UploadedBy IS 'The user who uploaded the file.'; +COMMENT ON COLUMN Attachments.RelatedToEntity IS 'The type of record this attachment is related to.'; +COMMENT ON COLUMN Attachments.RelatedToID IS 'The ID of the related record.'; +COMMENT ON COLUMN Attachments.CreatedAt IS 'Timestamp when the file was uploaded.'; + +-- Customer support tickets +CREATE TABLE SupportTickets ( + TicketID SERIAL PRIMARY KEY, + CustomerID INT REFERENCES Customers(CustomerID), + ContactID INT REFERENCES Contacts(ContactID), + Subject VARCHAR(255) NOT NULL, + Description TEXT, + Status VARCHAR(50) DEFAULT 'Open', + Priority VARCHAR(20) DEFAULT 'Normal', + AssignedTo INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE SupportTickets IS 'Tracks customer service and support requests.'; +COMMENT ON COLUMN SupportTickets.TicketID IS 'Unique identifier for the support ticket.'; +COMMENT ON COLUMN SupportTickets.CustomerID IS 'Foreign key linking the ticket to a customer.'; +COMMENT ON COLUMN SupportTickets.ContactID IS 'Foreign key linking the ticket to a specific contact.'; +COMMENT ON COLUMN SupportTickets.Subject IS 'A brief summary of the support issue.'; +COMMENT ON COLUMN SupportTickets.Description IS 'A detailed description of the issue.'; +COMMENT ON COLUMN SupportTickets.Status IS 'The current status of the ticket (e.g., Open, In Progress, Resolved).'; +COMMENT ON COLUMN SupportTickets.Priority IS 'The priority of the ticket (e.g., Low, Normal, High).'; +COMMENT ON COLUMN SupportTickets.AssignedTo IS 'The support agent the ticket is assigned to.'; +COMMENT ON COLUMN SupportTickets.CreatedAt IS 'Timestamp when the ticket was created.'; + +-- Comments on support tickets +CREATE TABLE TicketComments ( + CommentID SERIAL PRIMARY KEY, + TicketID INT REFERENCES SupportTickets(TicketID) ON DELETE CASCADE, + Comment TEXT NOT NULL, + CreatedBy INT REFERENCES Users(UserID), + CreatedAt TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP +); +COMMENT ON TABLE TicketComments IS 'Stores comments and updates related to a support ticket.'; +COMMENT ON COLUMN TicketComments.CommentID IS 'Unique identifier for the comment.'; +COMMENT ON COLUMN TicketComments.TicketID IS 'Foreign key linking the comment to a support ticket.'; +COMMENT ON COLUMN TicketComments.Comment IS 'The text content of the comment.'; +COMMENT ON COLUMN TicketComments.CreatedBy IS 'The user who added the comment.'; +COMMENT ON COLUMN TicketComments.CreatedAt IS 'Timestamp when the comment was added.'; + + +-- SQL Script 2: Data Insertion (DML) +-- This script populates the tables with sample data. + +-- Insert Roles +INSERT INTO Roles (RoleName) VALUES ('Admin'), ('Sales Manager'), ('Sales Representative'), ('Support Agent'); + +-- Insert Users +INSERT INTO Users (Username, PasswordHash, Email, FirstName, LastName) VALUES +('admin', 'hashed_password', 'admin@example.com', 'Admin', 'User'), +('sales_manager', 'hashed_password', 'manager@example.com', 'John', 'Doe'), +('sales_rep1', 'hashed_password', 'rep1@example.com', 'Jane', 'Smith'), +('sales_rep2', 'hashed_password', 'rep2@example.com', 'Peter', 'Jones'), +('support_agent1', 'hashed_password', 'support1@example.com', 'Mary', 'Williams'); + +-- Assign Roles to Users +INSERT INTO UserRoles (UserID, RoleID) VALUES +(1, 1), (2, 2), (3, 3), (4, 3), (5, 4); + +-- Insert Customers +INSERT INTO Customers (CustomerName, Industry, Website, Phone, Address, City, State, ZipCode, Country, AssignedTo) VALUES +('ABC Corporation', 'Technology', 'http://www.abccorp.com', '123-456-7890', '123 Tech Park', 'Techville', 'CA', '90210', 'USA', 3), +('Innovate Inc.', 'Software', 'http://www.innovate.com', '234-567-8901', '456 Innovation Dr', 'Devtown', 'TX', '75001', 'USA', 4), +('Global Solutions', 'Consulting', 'http://www.globalsolutions.com', '345-678-9012', '789 Global Ave', 'Businesston', 'NY', '10001', 'USA', 3), +('Data Dynamics', 'Analytics', 'http://www.datadynamics.com', '456-123-7890', '789 Data Dr', 'Metropolis', 'IL', '60601', 'USA', 4), +('Synergy Solutions', 'HR', 'http://www.synergysolutions.com', '789-456-1230', '101 Synergy Blvd', 'Union City', 'NJ', '07087', 'USA', 3); + +-- Insert Contacts +INSERT INTO Contacts (CustomerID, FirstName, LastName, Email, Phone, JobTitle) VALUES +(1, 'Alice', 'Wonder', 'alice.wonder@abccorp.com', '123-456-7891', 'CTO'), +(1, 'Bob', 'Builder', 'bob.builder@abccorp.com', '123-456-7892', 'Project Manager'), +(2, 'Charlie', 'Chocolate', 'charlie.chocolate@innovate.com', '234-567-8902', 'CEO'), +(3, 'Diana', 'Prince', 'diana.prince@globalsolutions.com', '345-678-9013', 'Consultant'), +(4, 'Leo', 'Lytics', 'leo.lytics@datadynamics.com', '456-123-7891', 'Data Scientist'), +(5, 'Hannah', 'Resources', 'hannah.r@synergysolutions.com', '789-456-1231', 'HR Manager'); + +-- Insert Leads +INSERT INTO Leads (FirstName, LastName, Email, Phone, Company, Status, Source, AssignedTo) VALUES +('Eve', 'Apple', 'eve.apple@email.com', '456-789-0123', 'Future Gadgets', 'Qualified', 'Website', 3), +('Frank', 'Stein', 'frank.stein@email.com', '567-890-1234', 'Monster Corp', 'New', 'Referral', 4), +('Grace', 'Hopper', 'grace.hopper@email.com', '678-901-2345', 'Cobol Inc.', 'Contacted', 'Cold Call', 3), +('Ivy', 'Green', 'ivy.g@webmail.com', '890-123-4567', 'Eco Systems', 'New', 'Trade Show', 4), +('Jack', 'Nimble', 'jack.n@fastmail.com', '901-234-5678', 'Quick Corp', 'Qualified', 'Website', 3); + +-- Insert Opportunities +INSERT INTO Opportunities (CustomerID, OpportunityName, Stage, Amount, CloseDate, AssignedTo) VALUES +(1, 'ABC Corp Website Redesign', 'Proposal', 50000.00, '2025-08-30', 3), +(2, 'Innovate Inc. Mobile App', 'Qualification', 75000.00, '2025-09-15', 4), +(3, 'Global Solutions IT Consulting', 'Negotiation', 120000.00, '2025-08-20', 3), +(4, 'Analytics Platform Subscription', 'Proposal', 90000.00, '2025-09-30', 4), +(5, 'HR Software Implementation', 'Prospecting', 65000.00, '2025-10-25', 3); + +-- Insert Product Categories +INSERT INTO ProductCategories (CategoryName, Description) VALUES +('Software', 'Business and productivity software'), +('Hardware', 'Computer hardware and peripherals'), +('Services', 'Consulting and support services'); + +-- Insert Products +INSERT INTO Products (ProductName, CategoryID, Description, Price, StockQuantity) VALUES +('CRM Pro', 1, 'Advanced CRM Software Suite', 1500.00, 100), +('Office Laptop Model X', 2, 'High-performance laptop for business', 1200.00, 50), +('IT Support Package', 3, '24/7 IT support services', 300.00, 200), +('Analytics Dashboard Pro', 1, 'Advanced analytics dashboard', 2500.00, 75), +('Ergonomic Office Chair', 2, 'Comfortable chair for long hours', 350.00, 150); + +-- Insert Sales Orders +INSERT INTO SalesOrders (CustomerID, OpportunityID, OrderDate, Status, TotalAmount, AssignedTo) VALUES +(1, 1, '2025-07-20', 'Shipped', 1500.00, 3), +(2, 2, '2025-07-22', 'Pending', 2400.00, 4), +(3, 3, '2025-07-24', 'Delivered', 300.00, 3), +(4, 4, '2025-07-25', 'Pending', 2500.00, 4); + +-- Insert Sales Order Items +INSERT INTO SalesOrderItems (OrderID, ProductID, Quantity, UnitPrice) VALUES +(1, 1, 1, 1500.00), +(2, 2, 2, 1200.00), +(3, 3, 1, 300.00), +(4, 4, 1, 2500.00); + +-- Insert Invoices +INSERT INTO Invoices (OrderID, InvoiceDate, DueDate, TotalAmount, Status) VALUES +(1, '2025-07-21', '2025-08-20', 1500.00, 'Paid'), +(2, '2025-07-23', '2025-08-22', 2400.00, 'Unpaid'), +(3, '2025-07-24', '2025-08-23', 300.00, 'Paid'), +(4, '2025-07-25', '2025-08-24', 2500.00, 'Unpaid'); + +-- Insert Payments +INSERT INTO Payments (InvoiceID, PaymentDate, Amount, PaymentMethod) VALUES +(1, '2025-07-25', 1500.00, 'Credit Card'), +(3, '2025-07-25', 300.00, 'Bank Transfer'); + +-- Insert Campaigns +INSERT INTO Campaigns (CampaignName, StartDate, EndDate, Budget, Status, Owner) VALUES +('Summer Sale 2025', '2025-06-01', '2025-08-31', 10000.00, 'Active', 2), +('Q4 Product Launch', '2025-10-01', '2025-12-31', 25000.00, 'Planned', 2); + +-- Insert Campaign Members +INSERT INTO CampaignMembers (CampaignID, LeadID, Status) VALUES +(1, 1, 'Responded'), +(1, 2, 'Sent'), +(1, 4, 'Sent'); +INSERT INTO CampaignMembers (CampaignID, ContactID, Status) VALUES +(1, 4, 'Sent'), +(1, 5, 'Responded'); + +-- Insert Tasks +INSERT INTO Tasks (Title, Description, DueDate, Status, Priority, AssignedTo, RelatedToEntity, RelatedToID) VALUES +('Follow up with ABC Corp', 'Discuss proposal details', '2025-08-01', 'In Progress', 'High', 3, 'Opportunity', 1), +('Prepare demo for Innovate Inc.', 'Customize demo for their needs', '2025-08-05', 'Not Started', 'Normal', 4, 'Opportunity', 2), +('Send updated proposal to Global Solutions', 'Include new service terms', '2025-07-28', 'Completed', 'High', 3, 'Opportunity', 3), +('Schedule initial call with Synergy Solutions', 'Discuss HR software needs', '2025-08-02', 'Not Started', 'Normal', 3, 'Customer', 5); + +-- Insert Notes +INSERT INTO Notes (Content, CreatedBy, RelatedToEntity, RelatedToID) VALUES +('Alice is very interested in the mobile integration features.', 3, 'Contact', 1), +('Lead from the tech conference last week.', 4, 'Lead', 2), +('Customer is looking for a cloud-based solution.', 4, 'Opportunity', 4), +('Met Ivy at the GreenTech expo. Promising lead.', 4, 'Lead', 4); + +-- Insert Attachments +INSERT INTO Attachments (FileName, FilePath, FileSize, FileType, UploadedBy, RelatedToEntity, RelatedToID) VALUES +('proposal_v1.pdf', '/attachments/proposal_v1.pdf', 102400, 'application/pdf', 3, 'Opportunity', 1), +('analytics_brochure.pdf', '/attachments/analytics_brochure.pdf', 256000, 'application/pdf', 4, 'Opportunity', 4); + +-- Insert Support Tickets +INSERT INTO SupportTickets (CustomerID, ContactID, Subject, Description, Status, Priority, AssignedTo) VALUES +(1, 1, 'Cannot login to portal', 'User Alice Wonder is unable to access the customer portal.', 'Resolved', 'High', 5), +(2, 3, 'Billing question', 'Question about the last invoice.', 'In Progress', 'Normal', 5), +(3, 4, 'Feature Request: Dark Mode', 'Requesting dark mode for the user dashboard.', 'Open', 'Low', 5), +(1, 2, 'Integration issue with calendar', 'Tasks are not syncing with Google Calendar.', 'In Progress', 'High', 5); + +-- Insert Ticket Comments +INSERT INTO TicketComments (TicketID, Comment, CreatedBy) VALUES +(1, 'Have reset the password. Please ask the user to try again.', 5), +(1, 'User confirmed they can now log in. Closing the ticket.', 5), +(2, 'Checking API logs for sync errors.', 5), +(3, 'Feature has been added to the development backlog.', 5); + +-- SQL Script 3: Insert More Demo Data (DML) +-- This script adds more sample data to the CRM database. +-- Run this script AFTER running 1_create_tables.sql and 2_insert_data.sql. + +-- Insert more Customers (starting from CustomerID 6) +INSERT INTO Customers (CustomerName, Industry, Website, Phone, Address, City, State, ZipCode, Country, AssignedTo) VALUES +('Quantum Innovations', 'R&D', 'http://www.quantuminnovate.com', '555-0101', '100 Research Pkwy', 'Quantumville', 'MA', '02139', 'USA', 3), +('HealthFirst Medical', 'Healthcare', 'http://www.healthfirst.com', '555-0102', '200 Health Blvd', 'Wellnesston', 'FL', '33101', 'USA', 4), +('GreenScape Solutions', 'Environmental', 'http://www.greenscape.com', '555-0103', '300 Nature Way', 'Ecoville', 'OR', '97201', 'USA', 3), +('Pinnacle Finance', 'Finance', 'http://www.pinnaclefinance.com', '555-0104', '400 Wall St', 'Financeton', 'NY', '10005', 'USA', 4), +('Creative Minds Agency', 'Marketing', 'http://www.creativeminds.com', '555-0105', '500 Ad Ave', 'Creator City', 'CA', '90028', 'USA', 3); + +-- Insert more Contacts (starting from ContactID 7) +-- Assuming CustomerIDs 6-10 were just created +INSERT INTO Contacts (CustomerID, FirstName, LastName, Email, Phone, JobTitle) VALUES +(6, 'Quentin', 'Physics', 'q.physics@quantuminnovate.com', '555-0101-1', 'Lead Scientist'), +(7, 'Helen', 'Healer', 'h.healer@healthfirst.com', '555-0102-1', 'Hospital Administrator'), +(7, 'Marcus', 'Welby', 'm.welby@healthfirst.com', '555-0102-2', 'Chief of Medicine'), +(8, 'Gary', 'Gardener', 'g.gardener@greenscape.com', '555-0103-1', 'CEO'), +(9, 'Fiona', 'Funds', 'f.funds@pinnaclefinance.com', '555-0104-1', 'Investment Banker'), +(10, 'Chris', 'Creative', 'c.creative@creativeminds.com', '555-0105-1', 'Art Director'), +(1, 'Carol', 'Client', 'c.client@abccorp.com', '123-456-7893', 'IT Director'); -- Contact for existing customer + +-- Insert more Leads (starting from LeadID 6) +INSERT INTO Leads (FirstName, LastName, Email, Phone, Company, Status, Source, AssignedTo) VALUES +('Ken', 'Knowledge', 'ken.k@university.edu', '555-0201', 'State University', 'Contacted', 'Referral', 4), +('Laura', 'Legal', 'laura.l@lawfirm.com', '555-0202', 'Law & Order LLC', 'New', 'Website', 3), +('Mike', 'Mechanic', 'mike.m@autoshop.com', '555-0203', 'Auto Fixers', 'Lost', 'Cold Call', 4), +('Nancy', 'Nurse', 'nancy.n@clinic.com', '555-0204', 'Community Clinic', 'Qualified', 'Trade Show', 3), +('Oscar', 'Organizer', 'oscar.o@events.com', '555-0205', 'Events R Us', 'New', 'Website', 4); + +-- Insert more Opportunities (starting from OpportunityID 6) +-- Assuming CustomerIDs 6-10 were just created +INSERT INTO Opportunities (CustomerID, OpportunityName, Stage, Amount, CloseDate, AssignedTo) VALUES +(6, 'Quantum Computing Simulation Software', 'Qualification', 250000.00, '2025-11-15', 3), +(7, 'Patient Management System Upgrade', 'Proposal', 180000.00, '2025-12-01', 4), +(8, 'Environmental Impact Reporting Tool', 'Negotiation', 75000.00, '2025-10-30', 3), +(9, 'Wealth Management Platform', 'Closed Won', 300000.00, '2025-07-25', 4), +(10, 'Digital Marketing Campaign Analytics', 'Prospecting', 45000.00, '2025-11-20', 3); + +-- Insert a new Product Category first +INSERT INTO ProductCategories (CategoryName, Description) VALUES +('Cloud Solutions', 'Cloud-based infrastructure and platforms'); -- This will be CategoryID 4 + +-- Insert more Products (starting from ProductID 6) +INSERT INTO Products (ProductName, CategoryID, Description, Price, StockQuantity) VALUES +('Wealth Management Suite', 1, 'Comprehensive software for financial advisors', 5000.00, 50), +('Patient Record System', 1, 'EHR system for clinics and hospitals', 4500.00, 80), +('Cloud Storage - 10TB Plan', 4, '10TB of enterprise cloud storage', 1000.00, 500); + +-- Insert more Sales Orders (starting from OrderID 5) +-- For the 'Closed Won' opportunity (ID 9) +INSERT INTO SalesOrders (CustomerID, OpportunityID, OrderDate, Status, TotalAmount, AssignedTo) VALUES +(9, 9, '2025-07-26', 'Delivered', 5000.00, 4); + +-- Insert more Sales Order Items (for OrderID 5) +INSERT INTO SalesOrderItems (OrderID, ProductID, Quantity, UnitPrice) VALUES +(5, 6, 1, 5000.00); -- Wealth Management Suite (ProductID 6) + +-- Insert more Invoices (starting from InvoiceID 5) +INSERT INTO Invoices (OrderID, InvoiceDate, DueDate, TotalAmount, Status) VALUES +(5, '2025-07-26', '2025-08-25', 5000.00, 'Paid'); + +-- Insert more Payments (starting from PaymentID 3) +INSERT INTO Payments (InvoiceID, PaymentDate, Amount, PaymentMethod) VALUES +(2, '2025-07-25', 2400.00, 'Bank Transfer'), -- Payment for an existing unpaid invoice +(5, '2025-07-26', 5000.00, 'Credit Card'); + +-- Insert a new Campaign (starting from CampaignID 3) +INSERT INTO Campaigns (CampaignName, StartDate, EndDate, Budget, Status, Owner) VALUES +('Healthcare Solutions Webinar', '2025-09-01', '2025-09-30', 7500.00, 'Planned', 2); + +-- Insert more Campaign Members +INSERT INTO CampaignMembers (CampaignID, LeadID, Status) VALUES +(3, 9, 'Sent'); -- Nancy Nurse (LeadID 9) for Healthcare campaign +INSERT INTO CampaignMembers (CampaignID, ContactID, Status) VALUES +(3, 8, 'Sent'), -- Helen Healer (ContactID 8) +(3, 9, 'Responded'); -- Marcus Welby (ContactID 9) + +-- Insert more Tasks (starting from TaskID 5) +INSERT INTO Tasks (Title, Description, DueDate, Status, Priority, AssignedTo, RelatedToEntity, RelatedToID) VALUES +('Draft contract for Pinnacle Finance', 'Based on the final negotiation terms.', '2025-07-28', 'Completed', 'High', 4, 'Opportunity', 9), +('Schedule webinar with HealthFirst', 'Discuss Patient Management System demo.', '2025-08-10', 'Not Started', 'High', 4, 'Opportunity', 7), +('Research Quantum Innovations needs', 'Prepare for qualification call.', '2025-08-15', 'In Progress', 'Normal', 3, 'Opportunity', 6), +('Call Nancy Nurse to follow up', 'Follow up from trade show conversation.', '2025-08-05', 'Not Started', 'Normal', 3, 'Lead', 9); + +-- Insert more Notes (starting from NoteID 5) +INSERT INTO Notes (Content, CreatedBy, RelatedToEntity, RelatedToID) VALUES +('Pinnacle deal closed! Great work team.', 2, 'Opportunity', 9), +('GreenScape is looking for a solution before year-end for compliance reasons.', 3, 'Opportunity', 8), +('Nancy was very engaged at the booth, good prospect.', 3, 'Lead', 9); + +-- Insert more Support Tickets (starting from TicketID 5) +INSERT INTO SupportTickets (CustomerID, ContactID, Subject, Description, Status, Priority, AssignedTo) VALUES +(4, 5, 'Dashboard data not refreshing', 'The main dashboard widgets are not updating in real-time.', 'Open', 'High', 5), +(5, 6, 'Report generation is slow', 'Generating the quarterly HR report takes over 10 minutes.', 'In Progress', 'Normal', 5), +(9, 11, 'Login issue for new user', 'Fiona Funds cannot log into the new Wealth Management platform.', 'Open', 'High', 5); + +-- Insert more Ticket Comments (starting from CommentID 5) +INSERT INTO TicketComments (TicketID, Comment, CreatedBy) VALUES +(2, 'Invoice has been resent to the customer.', 5), -- Comment on existing ticket +(4, 'The calendar sync issue seems to be related to a recent Google API update. Investigating.', 5), -- Comment on existing ticket +(5, 'Escalated to engineering to check the database query performance.', 5), +(6, 'Confirmed the issue is with the real-time data service. Restarting the service.', 5); + +-- Update existing records to show data changes +UPDATE Leads SET Status = 'Contacted' WHERE LeadID = 2; -- Frank Stein +UPDATE Invoices SET Status = 'Paid' WHERE InvoiceID = 2; -- Innovate Inc. invoice From e49c4f4b69e894bd11ea5f6a77dc17307487a96f Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 20:06:07 +0300 Subject: [PATCH 36/58] fix expired token --- api/index.py | 114 +++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 93 insertions(+), 21 deletions(-) diff --git a/api/index.py b/api/index.py index 018709bb..23dd30d7 100644 --- a/api/index.py +++ b/api/index.py @@ -4,6 +4,7 @@ import logging import os import random +import time from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError from functools import wraps @@ -35,16 +36,57 @@ SECRET_TOKEN_ERP = os.getenv("SECRET_TOKEN_ERP") +def validate_and_cache_user(): + """ + Helper function to validate OAuth token and cache user info. + Returns (user_info, is_authenticated) tuple. + """ + user_info = session.get("google_user") + token_validated_at = session.get("token_validated_at", 0) + current_time = time.time() + + # Use cached user info if it's less than 15 minutes old + if user_info and (current_time - token_validated_at) < 900: # 15 minutes + return user_info, True + + # If no valid session data or cache expired, check OAuth token + if not google.authorized: + session.clear() + return None, False + + try: + # Make network call to validate token + resp = google.get("/oauth2/v2/userinfo") + if not resp.ok: + session.clear() + return None, False + + user_info = resp.json() + session["google_user"] = user_info + session["token_validated_at"] = current_time + return user_info, True + + except Exception as e: + logging.warning(f"OAuth validation error: {e}") + session.clear() + return None, False + + def token_required(f): """Decorator to protect routes with token authentication""" @wraps(f) def decorated_function(*args, **kwargs): - user_info = session.get("google_user") - if user_info: - g.user_id = user_info.get("id") - else: - return jsonify(message="Unauthorized"), 401 + user_info, is_authenticated = validate_and_cache_user() + + if not is_authenticated: + return jsonify(message="Unauthorized - Please log in"), 401 + + g.user_id = user_info.get("id") + if not g.user_id: + session.clear() + return jsonify(message="Unauthorized - Invalid user"), 401 + return f(*args, **kwargs) return decorated_function @@ -68,15 +110,28 @@ def decorated_function(*args, **kwargs): app.register_blueprint(google_bp, url_prefix="/login") +@app.errorhandler(Exception) +def handle_oauth_error(error): + """Handle OAuth-related errors gracefully""" + # Check if it's an OAuth-related error + if "token" in str(error).lower() or "oauth" in str(error).lower(): + logging.warning(f"OAuth error occurred: {error}") + session.clear() + return redirect(url_for("home")) + + # For other errors, let them bubble up + raise error + + @app.route("/") def home(): """Home route""" - is_authenticated = "google_oauth_token" in session - if is_authenticated: - resp = google.get("/oauth2/v2/userinfo") - if resp.ok: - user_info = resp.json() - session["google_user"] = user_info + _, is_authenticated = validate_and_cache_user() + + # If not authenticated through OAuth, check for any stale session data + if not is_authenticated and not google.authorized: + session.pop("google_user", None) + return render_template("chat.j2", is_authenticated=is_authenticated) @@ -489,18 +544,35 @@ def suggestions(): def login_google(): if not google.authorized: return redirect(url_for("google.login")) - resp = google.get("/oauth2/v2/userinfo") - if resp.ok: - user_info = resp.json() - session["google_user"] = user_info - # You can set your own token/session logic here - return redirect(url_for("home")) - return "Could not fetch your information from Google.", 400 + + try: + resp = google.get("/oauth2/v2/userinfo") + if resp.ok: + user_info = resp.json() + session["google_user"] = user_info + return redirect(url_for("home")) + else: + # OAuth token might be expired, redirect to login + session.clear() + return redirect(url_for("google.login")) + except Exception as e: + logging.error(f"Login error: {e}") + session.clear() + return redirect(url_for("google.login")) @app.route("/logout") def logout(): session.clear() + # Also revoke the OAuth token if possible + if google.authorized: + try: + google.get( + "https://accounts.google.com/o/oauth2/revoke", + params={"token": google.access_token} + ) + except Exception as e: + logging.warning(f"Error revoking token: {e}") return redirect(url_for("home")) @app.route("/graphs//refresh", methods=["POST"]) @@ -522,10 +594,10 @@ def refresh_graph_schema(graph_id: str): "success": False, "error": "No database URL found for this graph" }), 400 - + # Perform schema refresh success, message = PostgresLoader.refresh_graph_schema(graph_id, db_url) - + if success: return jsonify({ "success": True, @@ -536,7 +608,7 @@ def refresh_graph_schema(graph_id: str): "success": False, "error": f"Failed to refresh schema: {message}" }), 500 - + except Exception as e: logging.error("Error in manual schema refresh: %s", e) return jsonify({ From 406c54a20acb1ab031305ce5c5fc5d92c0837efb Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 20:13:27 +0300 Subject: [PATCH 37/58] Add user image --- api/index.py | 4 +- api/static/css/chat.css | 83 +++++++++++++++++++++++++++++++++++++++++ api/static/js/chat.js | 26 +++++++++++++ api/templates/chat.j2 | 15 +++++++- 4 files changed, 124 insertions(+), 4 deletions(-) diff --git a/api/index.py b/api/index.py index 23dd30d7..69574864 100644 --- a/api/index.py +++ b/api/index.py @@ -126,13 +126,13 @@ def handle_oauth_error(error): @app.route("/") def home(): """Home route""" - _, is_authenticated = validate_and_cache_user() + user_info, is_authenticated = validate_and_cache_user() # If not authenticated through OAuth, check for any stale session data if not is_authenticated and not google.authorized: session.pop("google_user", None) - return render_template("chat.j2", is_authenticated=is_authenticated) + return render_template("chat.j2", is_authenticated=is_authenticated, user_info=user_info) @app.route("/graphs") diff --git a/api/static/css/chat.css b/api/static/css/chat.css index f617ae11..ec384ba7 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -881,6 +881,89 @@ body { background: #c0392b; } +/* User Profile Button Styles */ +.user-profile-btn { + position: fixed; + top: 20px; + right: 30px; + z-index: 2000; + width: 48px; + height: 48px; + border: none; + border-radius: 50%; + background: #fff; + box-shadow: 0 2px 8px rgba(0,0,0,0.15); + cursor: pointer; + transition: all 0.2s ease; + padding: 2px; +} + +.user-profile-btn:hover { + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + transform: translateY(-1px); +} + +.user-profile-img { + width: 100%; + height: 100%; + border-radius: 50%; + object-fit: cover; +} + +/* User Profile Dropdown */ +.user-profile-dropdown { + position: fixed; + top: 80px; + right: 30px; + z-index: 1999; + background: var(--falkor-secondary); + border: 1px solid var(--falkor-border-primary); + border-radius: 8px; + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + min-width: 200px; + display: none; +} + +.user-profile-dropdown.show { + display: block; +} + +.user-profile-info { + padding: 15px; + border-bottom: 1px solid var(--falkor-quaternary); +} + +.user-profile-name { + color: var(--text-primary); + font-weight: bold; + margin-bottom: 5px; +} + +.user-profile-email { + color: var(--text-secondary); + font-size: 0.9em; +} + +.user-profile-actions { + padding: 10px; +} + +.user-profile-logout { + width: 100%; + padding: 10px; + background: #e74c3c; + color: #fff; + border: none; + border-radius: 5px; + cursor: pointer; + font-weight: bold; + transition: background 0.2s; +} + +.user-profile-logout:hover { + background: #c0392b; +} + /* Destructive Confirmation Styles */ .destructive-confirmation-container { border: 2px solid #ff4444; diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 5de0dd26..612586f5 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -808,4 +808,30 @@ document.addEventListener('DOMContentLoaded', function() { }); }); } +}); + +// User Profile Dropdown Functionality +document.addEventListener('DOMContentLoaded', function() { + const userProfileBtn = document.getElementById('user-profile-btn'); + const userProfileDropdown = document.getElementById('user-profile-dropdown'); + + if (userProfileBtn && userProfileDropdown) { + // Toggle dropdown when profile button is clicked + userProfileBtn.addEventListener('click', function(e) { + e.stopPropagation(); + userProfileDropdown.classList.toggle('show'); + }); + + // Close dropdown when clicking outside + document.addEventListener('click', function(e) { + if (!userProfileBtn.contains(e.target) && !userProfileDropdown.contains(e.target)) { + userProfileDropdown.classList.remove('show'); + } + }); + + // Prevent dropdown from closing when clicking inside it + userProfileDropdown.addEventListener('click', function(e) { + e.stopPropagation(); + }); + } }); \ No newline at end of file diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 3990690f..f3297f23 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -10,8 +10,19 @@ - {% if is_authenticated %} - Logout + {% if is_authenticated and user_info %} + + {% endif %}
- +

Text-to-SQL(Natural Language to SQL Generator)

From e41c378c2b92d46a3353df6370d1983ff0dc8933 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 25 Jul 2025 23:41:49 +0300 Subject: [PATCH 41/58] fix the warning message --- api/static/css/chat.css | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index b4b0ce02..3aa062cf 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -267,7 +267,6 @@ body { word-wrap: break-word; overflow-wrap: break-word; overflow-x: auto; - white-space: pre-wrap; } /* Styles for formatted text blocks */ @@ -600,11 +599,25 @@ body { right: -60px; top: 50%; transform: translateY(-50%); - background: transparent; - color: var(--text-primary); + z-index: 2000; + width: 48px; + height: 48px; border: none; + border-radius: 50%; + background: var(--falkor-quaternary); + color: var(--text-primary); + box-shadow: 0 2px 8px rgba(0,0,0,0.15); cursor: pointer; - padding: 4px; + transition: all 0.2s ease; + display: flex; + align-items: center; + justify-content: center; + padding: 0; +} + +#reset-button:hover { + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + background: var(--falkor-accent); } #reset-button:disabled { @@ -1159,16 +1172,13 @@ body { /* Destructive Confirmation Styles */ .destructive-confirmation-container { - border: 2px solid #ff4444; - border-radius: 8px; - background: linear-gradient(135deg, #2a1f1f, #3a2222); - margin: 10px 0; transition: all 0.3s ease; } .destructive-confirmation-message { - background: none !important; - border: none !important; + border: 2px solid #ff4444; + border-radius: 8px; + background: linear-gradient(135deg, #2a1f1f, #3a2222); } .destructive-confirmation { From 6d375e76e5443b906e292da8a38165d2a7cc4ab0 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sat, 26 Jul 2025 00:18:14 +0300 Subject: [PATCH 42/58] update menu button for dark/light --- api/static/css/chat.css | 72 ++++---------------------------- api/static/public/icons/menu.svg | 3 -- api/templates/chat.j2 | 17 +++++--- 3 files changed, 20 insertions(+), 72 deletions(-) delete mode 100644 api/static/public/icons/menu.svg diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 3aa062cf..e928dd47 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -395,43 +395,6 @@ body { position: absolute; top: 20px; left: 20px; - z-index: 10; - background: var(--falkor-accent); - box-shadow: 0 2px 8px rgba(0,0,0,0.3); -} - -.menu-trigger { - border: none; - height: 32px; - width: 32px; - background: var(--falkor-accent); - border: 2px solid var(--falkor-primary); - box-shadow: 0 2px 6px rgba(0,0,0,0.2); -} - -.menu-trigger:hover { - background: var(--falkor-primary); - border-color: var(--falkor-tertiary); - box-shadow: 0 4px 12px rgba(0,0,0,0.4); -} - -.menu-trigger img { - width: 20px; - height: 20px; - filter: brightness(0) invert(1); - transition: filter 0.2s ease; -} - -#menu-button { - background: var(--falkor-accent); - border: 2px solid var(--falkor-primary); - box-shadow: 0 2px 6px rgba(0,0,0,0.2); -} - -#menu-button:hover { - background: var(--falkor-primary); - border-color: var(--falkor-tertiary); - box-shadow: 0 4px 12px rgba(0,0,0,0.4); } #menu-button img { @@ -599,6 +562,14 @@ body { right: -60px; top: 50%; transform: translateY(-50%); +} + +.action-button:hover { + box-shadow: 0 4px 12px rgba(0,0,0,0.25); + background: var(--falkor-accent); +} + +.action-button{ z-index: 2000; width: 48px; height: 48px; @@ -615,11 +586,6 @@ body { padding: 0; } -#reset-button:hover { - box-shadow: 0 4px 12px rgba(0,0,0,0.25); - background: var(--falkor-accent); -} - #reset-button:disabled { opacity: 0.5; cursor: not-allowed; @@ -1105,30 +1071,10 @@ body { } /* Theme Toggle Button Styles */ -.theme-toggle-btn { +#theme-toggle-btn { position: fixed; top: 20px; right: 90px; - z-index: 2000; - width: 48px; - height: 48px; - border: none; - border-radius: 50%; - background: var(--falkor-quaternary); - color: var(--text-primary); - box-shadow: 0 2px 8px rgba(0,0,0,0.15); - cursor: pointer; - transition: all 0.2s ease; - display: flex; - align-items: center; - justify-content: center; - padding: 0; -} - -.theme-toggle-btn:hover { - box-shadow: 0 4px 12px rgba(0,0,0,0.25); - transform: translateY(-1px); - background: var(--falkor-accent); } .theme-icon { diff --git a/api/static/public/icons/menu.svg b/api/static/public/icons/menu.svg deleted file mode 100644 index 99a5c97e..00000000 --- a/api/static/public/icons/menu.svg +++ /dev/null @@ -1,3 +0,0 @@ - - - diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index d2026ae3..10b1f4e5 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -11,7 +11,7 @@ - +
-
@@ -125,7 +130,7 @@ - +
From 5825cc2bbbaec06bff063ac03bf9a7757e269b11 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sat, 26 Jul 2025 09:41:15 +0300 Subject: [PATCH 43/58] tune down the colors --- api/static/css/chat.css | 183 +++++++++++++++++++++------------------- 1 file changed, 97 insertions(+), 86 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index e928dd47..9145c0bd 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -5,91 +5,91 @@ } :root { - /* FalkorDB brand colors - based on browser.falkordb.com */ - --falkor-primary: #7466FF; - /* FalkorDB primary teal */ - --falkor-secondary: #191919; - /* Dark navy blue - main background */ - --falkor-tertiary: #FF66B3; - /* Falkor Tertiary Color */ - --falkor-quaternary: #393939; - /* Falkor Quaternary Color */ - --dark-bg: black; - /* Slightly lighter navy for surfaces */ - --falkor-accent: #19B6C9; - /* Secondary teal for hover states */ - --falkor-border-primary: #7466FF; + /* Professional color palette - muted and business-appropriate */ + --falkor-primary: #5B6BC0; + /* Muted indigo - professional primary */ + --falkor-secondary: #1A1A1A; + /* Dark charcoal - main background */ + --falkor-tertiary: #B39DDB; + /* Muted lavender - subtle tertiary */ + --falkor-quaternary: #424242; + /* Medium gray - quaternary */ + --dark-bg: #0F0F0F; + /* Deep dark for surfaces */ + --falkor-accent: #26A69A; + /* Professional teal - subdued accent */ + --falkor-border-primary: #5B6BC0; /* Primary border color*/ - --falkor-border-secondary: #FF804D; - /* Secondary border color*/ - --falkor-border-tertiary: #FFFFFF; - /* Tertiary border color*/ - --text-primary: #FFFFFF; - /* Primary text color */ - --text-secondary: #CDD3DF; - /* Secondary text color */ - --text-tertiary: #525252; - /* Tertiary text color */ - --falkor-highlight: #20C9D8; - /* Highlight color */ - --accent-green: #4CAF50; - /* Success/final result color */ + --falkor-border-secondary: #90A4AE; + /* Muted blue-gray - professional secondary border */ + --falkor-border-tertiary: #E0E0E0; + /* Light gray border */ + --text-primary: #F5F5F5; + /* Soft white text */ + --text-secondary: #B0BEC5; + /* Muted blue-gray text */ + --text-tertiary: #616161; + /* Medium gray text */ + --falkor-highlight: #4DB6AC; + /* Subtle teal highlight */ + --accent-green: #66BB6A; + /* Professional green */ --icon-filter: invert(1); /* Dark theme - invert icons for white appearance */ - --bg-tertiary: #2a2a2a; - /* Tertiary background */ - --border-color: #555; - /* Border color */ + --bg-tertiary: #2E2E2E; + /* Professional dark gray */ + --border-color: #616161; + /* Subtle border color */ } /* Light theme variables */ [data-theme="light"] { - --falkor-primary: #7466FF; - --falkor-secondary: #FFFFFF; - --falkor-tertiary: #FF66B3; + --falkor-primary: #5B6BC0; + --falkor-secondary: #FAFAFA; + --falkor-tertiary: #B39DDB; --falkor-quaternary: #F5F5F5; --dark-bg: #FFFFFF; - --falkor-accent: #19B6C9; - --falkor-border-primary: #7466FF; - --falkor-border-secondary: #FF804D; - --falkor-border-tertiary: #333333; - --text-primary: #333333; - --text-secondary: #666666; - --text-tertiary: #999999; - --falkor-highlight: #20C9D8; - --accent-green: #2E7D32; - /* Darker green for light theme */ + --falkor-accent: #26A69A; + --falkor-border-primary: #5B6BC0; + --falkor-border-secondary: #90A4AE; + --falkor-border-tertiary: #424242; + --text-primary: #212121; + --text-secondary: #616161; + --text-tertiary: #9E9E9E; + --falkor-highlight: #4DB6AC; + --accent-green: #388E3C; + /* Professional dark green for light theme */ --icon-filter: invert(0); /* Light theme - no inversion for dark icons */ - --bg-tertiary: #F0F0F0; - /* Light tertiary background */ - --border-color: #DDD; + --bg-tertiary: #F8F8F8; + /* Very light gray background */ + --border-color: #E0E0E0; /* Light border color */ } /* System theme detection */ @media (prefers-color-scheme: light) { [data-theme="system"] { - --falkor-primary: #7466FF; - --falkor-secondary: #FFFFFF; - --falkor-tertiary: #FF66B3; + --falkor-primary: #5B6BC0; + --falkor-secondary: #FAFAFA; + --falkor-tertiary: #B39DDB; --falkor-quaternary: #F5F5F5; --dark-bg: #FFFFFF; - --falkor-accent: #19B6C9; - --falkor-border-primary: #7466FF; - --falkor-border-secondary: #FF804D; - --falkor-border-tertiary: #333333; - --text-primary: #333333; - --text-secondary: #666666; - --text-tertiary: #999999; - --falkor-highlight: #20C9D8; - --accent-green: #2E7D32; - /* Darker green for light theme */ + --falkor-accent: #26A69A; + --falkor-border-primary: #5B6BC0; + --falkor-border-secondary: #90A4AE; + --falkor-border-tertiary: #424242; + --text-primary: #212121; + --text-secondary: #616161; + --text-tertiary: #9E9E9E; + --falkor-highlight: #4DB6AC; + --accent-green: #388E3C; + /* Professional dark green for light theme */ --icon-filter: invert(0); /* Light theme - no inversion for dark icons */ - --bg-tertiary: #F0F0F0; - /* Light tertiary background */ - --border-color: #DDD; + --bg-tertiary: #F8F8F8; + /* Very light gray background */ + --border-color: #E0E0E0; /* Light border color */ } } @@ -110,6 +110,11 @@ body { overflow: hidden; } +/* Ensure all form elements inherit the consistent font */ +button, input, select, textarea { + font-family: inherit; +} + #container { height: 98%; width: 100%; @@ -121,7 +126,7 @@ body { #gradient { width: 100%; height: 2%; - background: linear-gradient(to right, #C15CFF, #FF5454); + background: linear-gradient(to right, var(--falkor-primary), var(--falkor-accent)); } .logo { @@ -312,7 +317,8 @@ body { flex-grow: 1; padding: 10px; border-radius: 6px; - box-shadow: 0 0 20px 3px var(--falkor-tertiary); + box-shadow: 0 0 8px 1px var(--falkor-primary); + border: 1px solid var(--border-color); } .input-container.loading { @@ -347,6 +353,10 @@ body { border: none; } +.input-button img { + filter: var(--icon-filter) brightness(0.8) saturate(0.3); +} + .input-button:hover { opacity: 0.7; } @@ -567,6 +577,7 @@ body { .action-button:hover { box-shadow: 0 4px 12px rgba(0,0,0,0.25); background: var(--falkor-accent); + transform: translateY(-1px); } .action-button{ @@ -871,15 +882,15 @@ body { } @keyframes shadow-fade { 0% { - box-shadow: 0 0 20px 3px var(--falkor-tertiary) + box-shadow: 0 0 8px 1px var(--falkor-primary) } 50% { - box-shadow: 0 0 40px 6px var(--falkor-tertiary) + box-shadow: 0 0 12px 2px var(--falkor-primary) } 100% { - box-shadow: 0 0 20px 3px var(--falkor-tertiary) + box-shadow: 0 0 8px 1px var(--falkor-primary) } } @@ -1057,7 +1068,7 @@ body { .user-profile-logout { width: 100%; padding: 10px; - background: #e74c3c; + background: #D32F2F; color: #fff; border: none; border-radius: 5px; @@ -1067,7 +1078,7 @@ body { } .user-profile-logout:hover { - background: #c0392b; + background: #B71C1C; } /* Theme Toggle Button Styles */ @@ -1122,9 +1133,9 @@ body { } .destructive-confirmation-message { - border: 2px solid #ff4444; + border: 1px solid #D32F2F; border-radius: 8px; - background: linear-gradient(135deg, #2a1f1f, #3a2222); + background: linear-gradient(135deg, var(--bg-tertiary), var(--falkor-quaternary)); } .destructive-confirmation { @@ -1140,7 +1151,7 @@ body { } .confirmation-text strong { - color: #ff6666; + color: #FFCDD2; } .confirmation-buttons { @@ -1163,29 +1174,29 @@ body { } .confirm-btn { - background: #ff4444; + background: #D32F2F; color: white; - border: 2px solid #ff4444; + border: 1px solid #D32F2F; } .confirm-btn:hover { - background: #ff2222; - border-color: #ff2222; - transform: translateY(-2px); - box-shadow: 0 4px 12px rgba(255, 68, 68, 0.4); + background: #B71C1C; + border-color: #B71C1C; + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(211, 47, 47, 0.3); } .cancel-btn { background: transparent; - color: #ffffff; - border: 2px solid #666666; + color: var(--text-primary); + border: 1px solid var(--border-color); } .cancel-btn:hover { - background: #666666; - border-color: #888888; - transform: translateY(-2px); - box-shadow: 0 4px 12px rgba(102, 102, 102, 0.4); + background: var(--bg-tertiary); + border-color: var(--text-secondary); + transform: translateY(-1px); + box-shadow: 0 2px 8px rgba(0, 0, 0, 0.2); } .confirm-btn:active, .cancel-btn:active { From 211b31c689b6e0856d77adb3d5d1dbd5dad73b0d Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sat, 26 Jul 2025 10:23:29 +0300 Subject: [PATCH 44/58] fix icon color in light mode --- api/static/css/chat.css | 5 +++++ api/static/public/icons/logo.svg | 25 ++++++++++--------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 9145c0bd..374461d6 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -134,6 +134,11 @@ button, input, select, textarea { width: auto; } +[data-theme="light"] .logo { + filter: invert(1); +} + + .chat-container { flex: 1 1 0; display: flex; diff --git a/api/static/public/icons/logo.svg b/api/static/public/icons/logo.svg index 3ae7f5cd..60ebfd33 100644 --- a/api/static/public/icons/logo.svg +++ b/api/static/public/icons/logo.svg @@ -3,30 +3,25 @@ - - - - - - - - - + + + + + + + + + - + - - - - - From 5810e9975f406e42eea8e8bae657be43ffe1b241 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 07:39:11 +0300 Subject: [PATCH 45/58] add docker --- Dockerfile | 34 +++++++++++++ requirements.txt | 125 +++++++++++++++++++++++------------------------ start.sh | 21 ++++++++ 3 files changed, 117 insertions(+), 63 deletions(-) create mode 100644 start.sh diff --git a/Dockerfile b/Dockerfile index e69de29b..34b7c0cc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -0,0 +1,34 @@ +# Use a single stage build with FalkorDB base image +FROM falkordb/falkordb:latest + +ENV PYTHONUNBUFFERED=1 \ + FALKORDB_HOST=localhost \ + FALKORDB_PORT=6379 + +USER root + +# Install Python and pip, netcat for wait loop in start.sh +RUN apt-get update && apt-get install -y \ + python3 \ + python3-pip \ + netcat-openbsd \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Copy requirements and install Python dependencies +COPY requirements.txt . +RUN python3 -m pip install --no-cache-dir --break-system-packages -r requirements.txt + +# Copy application code +COPY . . + +# Copy and make start.sh executable +COPY start.sh /start.sh +RUN chmod +x /start.sh + +EXPOSE 5000 6379 3000 + + +# Use start.sh as entrypoint +ENTRYPOINT ["/start.sh"] diff --git a/requirements.txt b/requirements.txt index a93da3ef..5bffe872 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,63 +1,62 @@ -aiohappyeyeballs==2.6.1 ; python_version == "3.12" -aiohttp==3.12.13 ; python_version == "3.12" -aiosignal==1.3.2 ; python_version == "3.12" -annotated-types==0.7.0 ; python_version == "3.12" -anyio==4.9.0 ; python_version == "3.12" -attrs==25.3.0 ; python_version == "3.12" -blinker==1.9.0 ; python_version == "3.12" -boto3==1.38.46 ; python_version == "3.12" -botocore==1.38.46 ; python_version == "3.12" -certifi==2025.6.15 ; python_version == "3.12" -charset-normalizer==3.4.2 ; python_version == "3.12" -click==8.2.1 ; python_version == "3.12" -colorama==0.4.6 ; python_version == "3.12" and platform_system == "Windows" -distro==1.9.0 ; python_version == "3.12" -falkordb==1.1.2 ; python_version == "3.12" -filelock==3.18.0 ; python_version == "3.12" -flask==3.1.1 ; python_version == "3.12" -flask-dance==7.1.0 ; python_version == "3.12" -frozenlist==1.7.0 ; python_version == "3.12" -fsspec==2025.5.1 ; python_version == "3.12" -h11==0.16.0 ; python_version == "3.12" -hf-xet==1.1.5 ; python_version == "3.12" and (platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "arm64" or platform_machine == "aarch64") -httpcore==1.0.9 ; python_version == "3.12" -httpx==0.28.1 ; python_version == "3.12" -huggingface-hub==0.33.1 ; python_version == "3.12" -idna==3.10 ; python_version == "3.12" -importlib-metadata==8.7.0 ; python_version == "3.12" -itsdangerous==2.2.0 ; python_version == "3.12" -jinja2==3.1.6 ; python_version == "3.12" -jiter==0.10.0 ; python_version == "3.12" -jmespath==1.0.1 ; python_version == "3.12" -jsonschema-specifications==2025.4.1 ; python_version == "3.12" -jsonschema==4.24.0 ; python_version == "3.12" -litellm==1.73.6 ; python_version == "3.12" -markupsafe==3.0.2 ; python_version == "3.12" -multidict==6.6.2 ; python_version == "3.12" -openai==1.93.0 ; python_version == "3.12" -packaging==25.0 ; python_version == "3.12" -propcache==0.3.2 ; python_version == "3.12" -pydantic-core==2.33.2 ; python_version == "3.12" -pydantic==2.11.7 ; python_version == "3.12" -pyjwt==2.9.0 ; python_version == "3.12" -python-dateutil==2.9.0.post0 ; python_version == "3.12" -python-dotenv==1.1.1 ; python_version == "3.12" -pyyaml==6.0.2 ; python_version == "3.12" -redis==5.3.0 ; python_version == "3.12" -referencing==0.36.2 ; python_version == "3.12" -regex==2024.11.6 ; python_version == "3.12" -requests==2.32.4 ; python_version == "3.12" -rpds-py==0.25.1 ; python_version == "3.12" -s3transfer==0.13.0 ; python_version == "3.12" -six==1.17.0 ; python_version == "3.12" -sniffio==1.3.1 ; python_version == "3.12" -tiktoken==0.9.0 ; python_version == "3.12" -tokenizers==0.21.2 ; python_version == "3.12" -tqdm==4.67.1 ; python_version == "3.12" -typing-extensions==4.14.0 ; python_version == "3.12" -typing-inspection==0.4.1 ; python_version == "3.12" -urllib3==2.5.0 ; python_version == "3.12" -werkzeug==3.1.3 ; python_version == "3.12" -yarl==1.20.1 ; python_version == "3.12" -zipp==3.23.0 ; python_version == "3.12" -psycopg2-binary==2.9.9 ; python_version == "3.12" +aiohappyeyeballs==2.6.1 +aiohttp==3.12.13 +aiosignal==1.3.2 +annotated-types==0.7.0 +anyio==4.9.0 +attrs==25.3.0 +blinker==1.9.0 +boto3==1.38.46 +botocore==1.38.46 +certifi==2025.6.15 +charset-normalizer==3.4.2 +click==8.2.1 +distro==1.9.0 +falkordb==1.1.2 +filelock==3.18.0 +flask==3.1.1 +flask-dance==7.1.0 +frozenlist==1.7.0 +fsspec==2025.5.1 +h11==0.16.0 +hf-xet==1.1.5 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.33.1 +idna==3.10 +importlib-metadata==8.7.0 +itsdangerous==2.2.0 +jinja2==3.1.6 +jiter==0.10.0 +jmespath==1.0.1 +jsonschema-specifications==2025.4.1 +jsonschema==4.24.0 +litellm==1.73.6 +markupsafe==3.0.2 +multidict==6.6.2 +openai==1.93.0 +packaging==25.0 +propcache==0.3.2 +pydantic-core==2.33.2 +pydantic==2.11.7 +pyjwt==2.9.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.1.1 +pyyaml==6.0.2 +redis==5.3.0 +referencing==0.36.2 +regex==2024.11.6 +requests==2.32.4 +rpds-py==0.25.1 +s3transfer==0.13.0 +six==1.17.0 +sniffio==1.3.1 +tiktoken==0.9.0 +tokenizers==0.21.2 +tqdm==4.67.1 +typing-extensions==4.14.0 +typing-inspection==0.4.1 +urllib3==2.5.0 +werkzeug==3.1.3 +yarl==1.20.1 +zipp==3.23.0 +psycopg2-binary==2.9.9 diff --git a/start.sh b/start.sh new file mode 100644 index 00000000..c0db7ef6 --- /dev/null +++ b/start.sh @@ -0,0 +1,21 @@ +#!/bin/bash +set -e + + +# Set default values if not set +FALKORDB_HOST="${FALKORDB_HOST:-localhost}" +FALKORDB_PORT="${FALKORDB_PORT:-6379}" + +# Start FalkorDB Redis server in background +redis-server --loadmodule /var/lib/falkordb/bin/falkordb.so & + +# Wait until FalkorDB is ready +echo "Waiting for FalkorDB to start on $FALKORDB_HOST:$FALKORDB_PORT..." + +while ! nc -z "$FALKORDB_HOST" "$FALKORDB_PORT"; do + sleep 0.5 +done + + +echo "FalkorDB is up - launching Flask..." +exec python3 -m flask --app api.index run --host=0.0.0.0 --port=5000 \ No newline at end of file From 5c10089b96dc8cfe1552cd768766c4111c7e7831 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 08:14:10 +0300 Subject: [PATCH 46/58] replace the system icon dark/light --- api/templates/chat.j2 | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index 10b1f4e5..d1c2c4fc 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -18,11 +18,21 @@ - - - - - + + + + + + + + + + + + + + + {% if is_authenticated and user_info %} From d44669605bac7f5316c49bb6179fbd505b25fb25 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 08:18:18 +0300 Subject: [PATCH 47/58] fix Connect to Postgres in light mode --- api/static/css/chat.css | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 374461d6..9381ee48 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -765,7 +765,7 @@ button, input, select, textarea { border-radius: 4px; margin-bottom: 1.5em; color: var(--text-primary); - background: var(--falkor-primary); + background: var(--falkor-quaternary); } .pg-modal-actions { display: flex; @@ -814,6 +814,7 @@ button, input, select, textarea { .pg-modal-cancel { background: var(--bg-tertiary); color: var(--text-primary); + border: 1px solid var(--border-color); } .pg-modal-cancel:hover { background: var(--border-color); @@ -822,6 +823,7 @@ button, input, select, textarea { background: var(--bg-tertiary); color: var(--text-secondary); cursor: not-allowed; + border: 1px solid var(--text-secondary); } .pg-modal-input:disabled { background: var(--bg-tertiary); From 1ad7437c24dfa99714fffc55e2f15ce6c01d5028 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 14:38:05 +0300 Subject: [PATCH 48/58] add responisve support --- api/static/css/chat.css | 172 +++++++++++++++++++++++++++++++++++++++- api/static/js/chat.js | 40 +++++++++- 2 files changed, 205 insertions(+), 7 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index 9381ee48..f89bc537 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -153,8 +153,26 @@ button, input, select, textarea { transition: margin-left 0.4s cubic-bezier(0.4, 0, 0.2, 1); } +/* Mobile responsive adjustments */ +@media (max-width: 768px) { + .chat-container { + padding-right: 10px; + padding-left: 10px; + padding-top: 10px; + padding-bottom: 10px; + } +} + +/* Desktop: Menu pushes content */ #menu-container.open~#chat-container { - margin-left: 0; + margin-left: 30dvw; +} + +/* Mobile: Menu overlays content (no pushing) */ +@media (max-width: 768px) { + #menu-container.open~#chat-container { + margin-left: 0; + } } .chat-header { @@ -279,6 +297,27 @@ button, input, select, textarea { overflow-x: auto; } +/* Mobile responsive messages */ +@media (max-width: 768px) { + .message { + max-width: 85%; + padding: 10px 12px; + font-size: 14px; + } + + .chat-header h1 { + font-size: 18px; + } + + #message-input { + font-size: 16px !important; + } + + #message-input::placeholder { + font-size: 16px !important; + } +} + /* Styles for formatted text blocks */ .sql-line, .array-line, .plain-line { word-wrap: break-word; @@ -393,6 +432,22 @@ button, input, select, textarea { pointer-events: auto; } +/* Mobile responsive menu */ +@media (max-width: 768px) { + #menu-container { + position: fixed; + top: 0; + left: 0; + height: 100vh; + z-index: 999; + } + + #menu-container.open { + width: 80vw; + padding: 15px; + } +} + #menu-header { display: flex; flex-direction: row; @@ -412,6 +467,14 @@ button, input, select, textarea { left: 20px; } +/* Mobile responsive side menu button */ +@media (max-width: 768px) { + #side-menu-button { + top: 15px; + left: 15px; + } +} + #menu-button img { rotate: 180deg; filter: brightness(0) invert(1); @@ -422,7 +485,6 @@ button, input, select, textarea { width: 32px; height: 32px; object-fit: cover; - background: none !important; } .menu-item { @@ -498,6 +560,14 @@ button, input, select, textarea { gap: 10px; } +/* Mobile responsive button container */ +@media (max-width: 768px) { + .button-container { + flex-direction: row; + gap: 8px; + } +} + #graph-select { height: 100%; padding: 8px 12px; @@ -517,6 +587,19 @@ button, input, select, textarea { cursor:pointer; } +/* Mobile responsive select elements */ +@media (max-width: 768px) { + #graph-select, + #custom-file-upload, + #open-pg-modal { + min-width: 120px; + width: auto; + padding: 8px 10px; + font-size: 14px; + flex: 1; + } +} + #graph-select:focus { outline: none; border-color: var(--falkor-border-primary); @@ -579,6 +662,22 @@ button, input, select, textarea { transform: translateY(-50%); } +/* Mobile responsive reset button */ +@media (max-width: 768px) { + #reset-button { + position: relative; + right: auto; + top: auto; + transform: none; + margin-left: 10px; + } + + .chat-input { + padding: 12px 16px; + gap: 12px; + } +} + .action-button:hover { box-shadow: 0 4px 12px rgba(0,0,0,0.25); background: var(--falkor-accent); @@ -586,7 +685,7 @@ button, input, select, textarea { } .action-button{ - z-index: 2000; + z-index: 100; width: 48px; height: 48px; border: none; @@ -600,6 +699,7 @@ button, input, select, textarea { align-items: center; justify-content: center; padding: 0; + position: relative; } #reset-button:disabled { @@ -625,6 +725,14 @@ button, input, select, textarea { min-height: 56px; } +/* Mobile responsive suggestions */ +@media (max-width: 768px) { + .suggestions-container { + flex-direction: row; + gap: 6px; + } +} + .suggestion-item { display: flex; align-items: center; @@ -642,6 +750,22 @@ button, input, select, textarea { border-radius: 6px; } +/* Mobile responsive suggestion items */ +@media (max-width: 768px) { + .suggestion-item { + width: calc(33.333% - 4px); + min-width: calc(33.333% - 4px); + max-width: calc(33.333% - 4px); + min-height: 48px; + padding: 8px 6px; + } + + .suggestion-item button p { + font-size: 12px; + line-height: 1.2; + } +} + .suggestion-item button { display: flex; align-items: center; @@ -753,6 +877,15 @@ button, input, select, textarea { text-align: center; min-width: 680px; } + +/* Mobile responsive modal */ +@media (max-width: 768px) { + .pg-modal-content { + padding: 1.5em 2em; + min-width: 90vw; + margin: 0 5vw; + } +} .pg-modal-title { margin-bottom: 1em; color: var(--text-primary); @@ -1048,6 +1181,15 @@ button, input, select, textarea { display: none; } +/* Mobile responsive dropdown */ +@media (max-width: 768px) { + .user-profile-dropdown { + top: 65px; + right: 15px; + min-width: 180px; + } +} + .user-profile-dropdown.show { display: block; } @@ -1102,6 +1244,30 @@ button, input, select, textarea { transition: all 0.3s ease; } +/* Mobile responsive user profile and theme toggle */ +@media (max-width: 768px) { + .action-button { + width: 40px; + height: 40px; + } + .user-profile-btn { + top: 15px; + right: 15px; + width: 40px; + height: 40px; + } + + #theme-toggle-btn { + top: 15px; + right: 80px; + } + + .theme-icon { + width: 18px; + height: 18px; + } +} + /* Theme icon states */ [data-theme="dark"] .theme-icon .sun, [data-theme="system"] .theme-icon .sun { diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 2ef8bff8..b8f2f868 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -375,16 +375,27 @@ async function sendMessage() { } function toggleMenu() { + // Check if we're on mobile (768px breakpoint to match CSS) + const isMobile = window.innerWidth <= 768; + if (!menuContainer.classList.contains('open')) { menuContainer.classList.add('open'); sideMenuButton.style.display = 'none'; - chatContainer.style.paddingRight = '10%'; - chatContainer.style.paddingLeft = '10%'; + + // Only adjust padding on desktop, not mobile (mobile uses overlay) + if (!isMobile) { + chatContainer.style.paddingRight = '10%'; + chatContainer.style.paddingLeft = '10%'; + } } else { menuContainer.classList.remove('open'); sideMenuButton.style.display = 'block'; - chatContainer.style.paddingRight = '20%'; - chatContainer.style.paddingLeft = '20%'; + + // Only adjust padding on desktop, not mobile (mobile uses overlay) + if (!isMobile) { + chatContainer.style.paddingRight = '20%'; + chatContainer.style.paddingLeft = '20%'; + } } } @@ -894,4 +905,25 @@ document.addEventListener('DOMContentLoaded', function() { }; themeToggleBtn.title = titles[currentTheme]; } +}); + +// Handle window resize to ensure proper menu behavior across breakpoints +window.addEventListener('resize', function() { + const isMobile = window.innerWidth <= 768; + + // If menu is open and we switch to mobile, remove any desktop padding + if (isMobile && menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = ''; + chatContainer.style.paddingLeft = ''; + } + // If menu is open and we switch to desktop, apply desktop padding + else if (!isMobile && menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = '10%'; + chatContainer.style.paddingLeft = '10%'; + } + // If menu is closed and we're on desktop, ensure default desktop padding + else if (!isMobile && !menuContainer.classList.contains('open')) { + chatContainer.style.paddingRight = '20%'; + chatContainer.style.paddingLeft = '20%'; + } }); \ No newline at end of file From 736d590e7318424272a04c784fc4b2fecb33311a Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 14:48:58 +0300 Subject: [PATCH 49/58] remove dead code --- api/static/js/chat.js | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/api/static/js/chat.js b/api/static/js/chat.js index b8f2f868..43527ff1 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -143,22 +143,6 @@ function formatBlock(text) { return lineDiv; }); } - if (text.includes('\n')) { - return text.split('\n').map((line, i) => { - const lineDiv = document.createElement('div'); - lineDiv.className = 'plain-line'; - lineDiv.textContent = line; - return lineDiv; - }); - } - if (text.includes('\n')) { - return text.split('\n').map((line, i) => { - const lineDiv = document.createElement('div'); - lineDiv.className = 'plain-line'; - lineDiv.textContent = line; - return lineDiv; - }); - } } function initChat() { From cc9f1973da0b85278956033b4edc295c1bad180d Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 15:14:30 +0300 Subject: [PATCH 50/58] fix #65 disable chat input when no schema --- api/static/js/chat.js | 76 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 43527ff1..88d1a5f1 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -151,11 +151,20 @@ function initChat() { item.classList.remove('active'); }); chatMessages.innerHTML = ''; - [confValue, expValue, missValue, ambValue].forEach((element) => { + [confValue, expValue, missValue].forEach((element) => { element.innerHTML = ''; }); - addMessage('Hello! How can I help you today?', false); - suggestionsContainer.style.display = 'flex'; + + // Check if we have graphs available + const graphSelect = document.getElementById("graph-select"); + if (graphSelect && graphSelect.options.length > 0 && graphSelect.options[0].value) { + addMessage('Hello! How can I help you today?', false); + suggestionsContainer.style.display = 'flex'; + } else { + addMessage('Hello! Please select a graph from the dropdown above or upload a schema to get started.', false); + suggestionsContainer.style.display = 'none'; + } + questions_history = []; result_history = []; } @@ -175,6 +184,13 @@ async function sendMessage() { const message = messageInput.value.trim(); if (!message) return; + // Check if a graph is selected + const selectedValue = document.getElementById("graph-select").value; + if (!selectedValue) { + addMessage("Please select a graph from the dropdown before sending a message.", false, true); + return; + } + // Cancel any ongoing request if (currentRequestController) { currentRequestController.abort(); @@ -196,7 +212,6 @@ async function sendMessage() { }); try { - const selectedValue = document.getElementById("graph-select").value; // Create an AbortController for this request currentRequestController = new AbortController(); @@ -591,9 +606,35 @@ document.addEventListener("DOMContentLoaded", function () { // Fetch available graphs fetch("/graphs?token=" + TOKEN) - .then(response => response.json()) + .then(response => { + if (!response.ok) { + if (response.status === 401) { + throw new Error("Authentication required. Please log in to access graphs."); + } + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + return response.json(); + }) .then(data => { graphSelect.innerHTML = ""; + + if (!data || data.length === 0) { + // No graphs available + const option = document.createElement("option"); + option.value = ""; + option.textContent = "No graphs available"; + option.disabled = true; + graphSelect.appendChild(option); + + // Disable chat input when no graphs are available + messageInput.disabled = true; + submitButton.disabled = true; + messageInput.placeholder = "Please upload a schema or connect a database to start chatting"; + + addMessage("No graphs are available. Please upload a schema file or connect to a database to get started.", false); + return; + } + data.forEach(graph => { const option = document.createElement("option"); option.value = graph; @@ -602,6 +643,11 @@ document.addEventListener("DOMContentLoaded", function () { graphSelect.appendChild(option); }); + // Re-enable chat input when graphs are available + messageInput.disabled = false; + submitButton.disabled = false; + messageInput.placeholder = "Describe the SQL query you want..."; + // Fetch suggestions for the first graph (if any) if (data.length > 0) { fetchSuggestions(); @@ -609,7 +655,25 @@ document.addEventListener("DOMContentLoaded", function () { }) .catch(error => { console.error("Error fetching graphs:", error); - addMessage("Sorry, there was an error fetching the available graphs: " + error.message, false); + + // Show appropriate error message and disable input + if (error.message.includes("Authentication required")) { + addMessage("Authentication required. Please log in to access your graphs.", false); + // Don't disable input for auth errors as user needs to log in + } else { + addMessage("Sorry, there was an error fetching the available graphs: " + error.message, false); + messageInput.disabled = true; + submitButton.disabled = true; + messageInput.placeholder = "Cannot connect to server"; + } + + // Add a placeholder option to show the error state + graphSelect.innerHTML = ""; + const option = document.createElement("option"); + option.value = ""; + option.textContent = error.message.includes("Authentication") ? "Please log in" : "Error loading graphs"; + option.disabled = true; + graphSelect.appendChild(option); }); // Function to fetch suggestions based on selected graph From 297371e0fd6eb495ad4e8c1f12388a9cef2a0e63 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 16:33:13 +0300 Subject: [PATCH 51/58] fix #67 switch user and bot sides: --- api/static/css/chat.css | 28 ++++++++++++++-------------- api/static/js/chat.js | 15 ++++++++++----- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index f89bc537..ccd42fd3 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -210,12 +210,12 @@ button, input, select, textarea { } .user-message-container { - justify-content: flex-start; + justify-content: flex-end; position: relative; } /* Hide the default "User" text when avatar is present */ -.user-message-container.has-avatar::before { +.user-message-container.has-avatar::after { display: none; } @@ -225,20 +225,20 @@ button, input, select, textarea { width: 32px; border-radius: 50%; object-fit: cover; - margin-right: 10px; + margin-left: 10px; border: 2px solid var(--falkor-quaternary); } .bot-message-container, .followup-message-container, .final-result-message-container { - justify-content: flex-end; + justify-content: flex-start; } -.user-message-container::before, -.bot-message-container::after, -.followup-message-container::after, -.final-result-message-container::after { +.user-message-container::after, +.bot-message-container::before, +.followup-message-container::before, +.final-result-message-container::before { height: 32px; width: 32px; content: 'Bot'; @@ -248,27 +248,27 @@ button, input, select, textarea { justify-content: center; font-weight: 500; font-size: 16px; - margin-left: 10px; + margin-right: 10px; border-radius: 100%; padding: 4px; } -.user-message-container::before { - margin-right: 10px; +.user-message-container::after { + margin-left: 10px; background: var(--falkor-quaternary); content: 'User'; } -.bot-message-container::after { +.bot-message-container::before { background: color-mix(in srgb, var(--falkor-tertiary) 33%, transparent); } -.followup-message-container::after { +.followup-message-container::before { background: color-mix(in srgb, var(--falkor-tertiary) 40%, transparent); border: 1px solid var(--falkor-tertiary); } -.final-result-message-container::after { +.final-result-message-container::before { background: color-mix(in srgb, var(--accent-green) 40%, transparent); border: 1px solid var(--accent-green); } diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 88d1a5f1..0499ef53 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -36,6 +36,8 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = messageDiv.className = "message"; messageDivContainer.className = "message-container"; + let userAvatar = null; + if (isFollowup) { messageDivContainer.className += " followup-message-container"; messageDiv.className += " followup-message"; @@ -45,13 +47,12 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = messageDivContainer.className += " user-message-container"; messageDiv.className += " user-message"; - // Add user profile image if userInfo is provided + // Prepare user avatar if userInfo is provided if (userInfo && userInfo.picture) { - const userAvatar = document.createElement('img'); + userAvatar = document.createElement('img'); userAvatar.src = userInfo.picture; userAvatar.alt = userInfo.name || 'User'; userAvatar.className = 'user-message-avatar'; - messageDivContainer.appendChild(userAvatar); messageDivContainer.classList.add('has-avatar'); } @@ -60,7 +61,6 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = result_history.push(message); messageDivContainer.className += " final-result-message-container"; messageDiv.className += " final-result-message"; - // messageDiv.textContent = message; } else { messageDivContainer.className += " bot-message-container"; messageDiv.className += " bot-message"; @@ -70,7 +70,7 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = } } - const block = formatBlock(message) + const block = formatBlock(message); if (block) { block.forEach(lineDiv => { @@ -82,9 +82,14 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = if (!isLoading) { messageDivContainer.appendChild(messageDiv); + if (userAvatar) { + messageDivContainer.appendChild(userAvatar); + } } + chatMessages.appendChild(messageDivContainer); chatMessages.scrollTop = chatMessages.scrollHeight; + return messageDiv; } From f064d95a43d2884885f5bf423857e181a1164226 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 16:57:11 +0300 Subject: [PATCH 52/58] fix #68 show the first letter as alt --- api/static/css/chat.css | 11 +++++++++-- api/static/js/chat.js | 2 +- api/templates/chat.j2 | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index ccd42fd3..091492aa 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -221,12 +221,17 @@ button, input, select, textarea { /* User message avatar styling */ .user-message-avatar { - height: 32px; - width: 32px; + height: 40px; + width: 40px; border-radius: 50%; object-fit: cover; margin-left: 10px; border: 2px solid var(--falkor-quaternary); + font-weight: 500; + font-size: 16px; + justify-content: center; + align-items: center; + display: flex; } .bot-message-container, @@ -1165,6 +1170,8 @@ button, input, select, textarea { height: 100%; border-radius: 50%; object-fit: cover; + font-weight: 500; + font-size: 22px; } /* User Profile Dropdown */ diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 0499ef53..3fcc01da 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -51,7 +51,7 @@ function addMessage(message, isUser = false, isFollowup = false, isFinalResult = if (userInfo && userInfo.picture) { userAvatar = document.createElement('img'); userAvatar.src = userInfo.picture; - userAvatar.alt = userInfo.name || 'User'; + userAvatar.alt = userInfo.name?.charAt(0).toUpperCase() || 'User'; userAvatar.className = 'user-message-avatar'; messageDivContainer.classList.add('has-avatar'); } diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index d1c2c4fc..7f8e5253 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -37,7 +37,7 @@ {% if is_authenticated and user_info %}
@@ -151,9 +155,12 @@
From 4fc97ea788095189b676f39120f23d3277bd9695 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sun, 27 Jul 2025 19:20:42 +0300 Subject: [PATCH 54/58] fix #58 add appoval message on reset --- api/static/css/chat.css | 68 +++++++++++++++++++++++++++++++++++++++-- api/static/js/chat.js | 35 ++++++++++++++++++++- api/templates/chat.j2 | 10 ++++++ 3 files changed, 110 insertions(+), 3 deletions(-) diff --git a/api/static/css/chat.css b/api/static/css/chat.css index b85a7055..cd2915e1 100644 --- a/api/static/css/chat.css +++ b/api/static/css/chat.css @@ -376,8 +376,7 @@ button, input, select, textarea { #message-input { color: var(--text-primary); - width: 100%; - height: 100%; + flex-grow: 1; background-color: transparent; border: none; font-size: 18px !important; @@ -423,6 +422,15 @@ button, input, select, textarea { display: none; } +/* Mobile responsive reset button */ +@media (max-width: 768px) { + .input-button { + width: 40px; + height: 40px; + } +} + + #menu-container { width: 0; min-width: 0; @@ -676,6 +684,11 @@ button, input, select, textarea { transform: translateY(-50%); } +#reset-button svg { + width: 20px; + height: 20px; +} + /* Mobile responsive reset button */ @media (max-width: 768px) { #reset-button { @@ -686,6 +699,11 @@ button, input, select, textarea { margin-left: 10px; } + #reset-button svg { + width: 18px; + height: 18px; + } + .chat-input { padding: 12px 16px; gap: 12px; @@ -1433,4 +1451,50 @@ button, input, select, textarea { border-color: #cccccc; transform: none; box-shadow: none; +} + +/* Reset Confirmation Modal */ +.reset-confirmation-modal { + display: none; + position: fixed; + top: 0; + left: 0; + width: 100vw; + height: 100vh; + background: rgba(0,0,0,0.6); + z-index: 2000; + align-items: center; + justify-content: center; +} + +.reset-confirmation-modal-content { + background: var(--falkor-secondary); + padding: 2em 3em; + border-radius: 10px; + box-shadow: 0 2px 16px rgba(0,0,0,0.2); + text-align: center; + color: var(--text-primary); + min-width: 400px; +} + +/* Mobile responsive reset modal */ +@media (max-width: 768px) { + .reset-confirmation-modal-content { + padding: 1.5em 2em; + min-width: 90vw; + margin: 0 5vw; + } +} + +.reset-confirmation-modal-content h3 { + color: var(--text-primary); + margin-bottom: 0.5em; + font-size: 1.3em; +} + +.reset-confirmation-modal-content p { + color: var(--text-secondary); + margin-bottom: 1.5em; + font-size: 1em; + line-height: 1.5; } \ No newline at end of file diff --git a/api/static/js/chat.js b/api/static/js/chat.js index 1f299343..7c4d4e74 100644 --- a/api/static/js/chat.js +++ b/api/static/js/chat.js @@ -588,7 +588,40 @@ menuButton.addEventListener('click', toggleMenu); sideMenuButton.addEventListener('click', toggleMenu); -newChatButton.addEventListener('click', initChat); +// Reset confirmation modal elements +const resetConfirmationModal = document.getElementById('reset-confirmation-modal'); +const resetConfirmBtn = document.getElementById('reset-confirm-btn'); +const resetCancelBtn = document.getElementById('reset-cancel-btn'); + +// Show reset confirmation modal instead of directly resetting +newChatButton.addEventListener('click', () => { + resetConfirmationModal.style.display = 'flex'; +}); + +// Handle reset confirmation +resetConfirmBtn.addEventListener('click', () => { + resetConfirmationModal.style.display = 'none'; + initChat(); +}); + +// Handle reset cancellation +resetCancelBtn.addEventListener('click', () => { + resetConfirmationModal.style.display = 'none'; +}); + +// Close modal when clicking outside of it +resetConfirmationModal.addEventListener('click', (e) => { + if (e.target === resetConfirmationModal) { + resetConfirmationModal.style.display = 'none'; + } +}); + +// Close modal with Escape key +document.addEventListener('keydown', (e) => { + if (e.key === 'Escape' && resetConfirmationModal.style.display === 'flex') { + resetConfirmationModal.style.display = 'none'; + } +}); // Add event listener to each suggestion item suggestionItems.forEach(item => { diff --git a/api/templates/chat.j2 b/api/templates/chat.j2 index bb2a3218..79bf8c71 100644 --- a/api/templates/chat.j2 +++ b/api/templates/chat.j2 @@ -179,6 +179,16 @@
+
+
+

Reset Session

+

Are you sure you want to reset the current session? This will clear all chat history and start a new conversation.

+
+ + +
+
+
{# Set authentication state for JS before loading chat.js #}