From fe90292a2ba8932688ce3e3b30474391a1f9bd9c Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 22 Aug 2025 19:18:24 +0300 Subject: [PATCH 1/3] add toolbar to mobile --- app/public/css/buttons.css | 91 +++++++++++++--- app/public/css/chat-components.css | 18 ++++ app/public/css/responsive.css | 85 +++++++++++++-- app/templates/components/left_toolbar.j2 | 127 +++++++++++++++++++---- 4 files changed, 276 insertions(+), 45 deletions(-) diff --git a/app/public/css/buttons.css b/app/public/css/buttons.css index d89cb3a4..500ea70c 100644 --- a/app/public/css/buttons.css +++ b/app/public/css/buttons.css @@ -343,11 +343,6 @@ background: linear-gradient(180deg, rgba(255,255,255,0.02), rgba(255,255,255,0.01)); } -.toolbar-button[aria-pressed="true"] { - background: var(--falkor-accent, rgba(150,120,220,0.22)); - transform: translateY(-4px) scale(1.03); -} - #toolbar-buttons { display: flex; flex-direction: column; @@ -364,19 +359,83 @@ flex-shrink: 0; } -@media (max-width: 768px) { - #left-toolbar { - left: 8px; - top: 12px; - bottom: auto; - width: auto; - height: 48px; - padding: 4px; - border-radius: 8px; - flex-direction: row; +/* Collapsed state: only show the burger button. Remove background/border/shadow so the bar is unobtrusive. */ +#left-toolbar.collapsed { + /* minimize width to avoid blocking content */ + width: 0px; + padding: 0; + transition: width 220ms ease, padding 220ms ease, background 200ms ease; + background: transparent; + box-shadow: none; + border: none; + align-items: flex-start; /* keep burger at top */ + overflow: visible; + position: relative; /* allow z-index for inner button */ + pointer-events: none; /* let clicks pass through, but burger will override */ +} + +#left-toolbar.collapsed .toolbar-button { + display: none; /* hide all toolbar buttons */ +} + +#left-toolbar.collapsed #burger-toggle-btn { + display: flex; + width: 48px; + height: 48px; + margin: 6px 0; + position: absolute; + top: 6px; + left: 6px; + z-index: 9999; /* keep on top so it's always clickable */ + pointer-events: auto; +} + +/* Mobile fallback: without JS, default to only the burger visible on small viewports */ +@media (max-width: 767px) { + #left-toolbar.collapsed { + width: 0px; + padding: 0; + background: transparent; + box-shadow: none; + border: none; + align-items: flex-start; + pointer-events: none; + } + + /* When NOT collapsed on mobile, show full toolbar */ + #left-toolbar:not(.collapsed) { + width: 48px; + padding: 6px 6px; + background: var(--bg-tertiary, rgba(255,255,255,0.02)); + box-shadow: 0 6px 18px rgba(0,0,0,0.25); + border: 1px solid var(--border-color, rgba(255,255,255,0.03)); + backdrop-filter: blur(6px); align-items: center; + pointer-events: auto; + z-index: 1050; + } + + /* On mobile hide buttons only when explicitly collapsed (so burger can open toolbar) */ + #left-toolbar.collapsed .toolbar-button { display: none; } + #left-toolbar.collapsed #burger-toggle-btn { + display: flex; + width: 48px; + height: 48px; + position: absolute; + top: 6px; + left: 6px; + pointer-events: auto; + } + /* When not collapsed, show toolbar buttons as normal */ + #left-toolbar:not(.collapsed) .toolbar-button { display: flex; } + #left-toolbar:not(.collapsed) #burger-toggle-btn { + display: flex; + position: relative; + top: auto; + left: auto; + pointer-events: auto; } - #left-toolbar-inner { flex-direction: row; gap: 6px; } } + diff --git a/app/public/css/chat-components.css b/app/public/css/chat-components.css index 78c05656..df098210 100644 --- a/app/public/css/chat-components.css +++ b/app/public/css/chat-components.css @@ -28,6 +28,24 @@ background-color: var(--falkor-secondary); } +@media (max-width: 768px) { + .chat-messages { + padding: 2px; + } + + .chat-container { + padding-right: 10px; + padding-left: 10px; + padding-top: 10px; + padding-bottom: 10px; + /* Ensure no horizontal overflow */ + max-width: 100vw; + overflow-x: hidden; + box-sizing: border-box; + /* Ensure content is accessible when toolbar is open */ + transition: margin-left 220ms ease; + } +} .message-container { display: flex; flex-direction: row; diff --git a/app/public/css/responsive.css b/app/public/css/responsive.css index 01be5d39..0fc73bb5 100644 --- a/app/public/css/responsive.css +++ b/app/public/css/responsive.css @@ -2,11 +2,55 @@ /* Layout Responsive */ @media (max-width: 768px) { - .chat-container { - padding-right: 10px; - padding-left: 10px; - padding-top: 10px; - padding-bottom: 10px; + /* When left toolbar is open, push content to make room */ + body.left-toolbar-open .chat-container { + max-width: calc(100vw - 48px); + } + + /* Ensure body doesn't overflow */ + body { + max-width: 100vw; + overflow-x: hidden; + } + + /* Main container mobile adjustments */ + #container { + width: 100%; + max-width: 100vw; + overflow-x: hidden; + box-sizing: border-box; + transition: margin-left 220ms ease; + } + + /* Adjust container when toolbar is open */ + body.left-toolbar-open #container { + margin-left: 48px; + width: calc(100% - 48px); + max-width: calc(100vw - 48px); + } + + /* Prevent any background showing behind the shifted content */ + body.left-toolbar-open { + background: var(--falkor-secondary); + } + + /* Ensure chat container fills the available space properly */ + body.left-toolbar-open .chat-container { + background: var(--falkor-secondary); + border-left: none; + } + + /* Ensure chat header elements are properly positioned when toolbar is open */ + body.left-toolbar-open .chat-header { + padding-left: 15px; /* Add extra padding to prevent overlap */ + width: 100%; + box-sizing: border-box; + } + + /* Ensure dropdown and buttons in header don't get cut off */ + body.left-toolbar-open .chat-header > * { + margin-left: 0; + width: 100%; } } @@ -99,8 +143,35 @@ } .chat-input { - padding: 12px 16px; - gap: 12px; + padding: 8px 12px; + gap: 8px; + /* Ensure it fits within viewport width */ + margin: 0; + box-sizing: border-box; + } + + .input-container { + /* Reduce padding on mobile to maximize input space */ + padding: 8px; + gap: 4px; + min-width: 0; /* Allow container to shrink */ + flex-shrink: 1; + } + + #message-input { + font-size: 16px !important; /* Prevent zoom on iOS */ + min-width: 0; /* Allow input to shrink */ + } + + #message-input::placeholder { + font-size: 16px !important; + } + + /* Ensure buttons don't overflow */ + .input-button { + width: 40px; + height: 40px; + flex-shrink: 0; /* Prevent buttons from shrinking */ } } diff --git a/app/templates/components/left_toolbar.j2 b/app/templates/components/left_toolbar.j2 index b706f464..c7310a4b 100644 --- a/app/templates/components/left_toolbar.j2 +++ b/app/templates/components/left_toolbar.j2 @@ -1,34 +1,39 @@ {# Left vertical toolbar to host expandable buttons (keeps future buttons easy to add) #} \ No newline at end of file + + + \ No newline at end of file From d512568c78e46c9fa3db6831a1ec23412bb3cade Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Fri, 22 Aug 2025 19:41:24 +0300 Subject: [PATCH 2/3] set select size --- app/public/css/responsive.css | 104 ++++++++++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 12 deletions(-) diff --git a/app/public/css/responsive.css b/app/public/css/responsive.css index 0fc73bb5..57f16e7c 100644 --- a/app/public/css/responsive.css +++ b/app/public/css/responsive.css @@ -100,22 +100,102 @@ /* Button Container Responsive */ @media (max-width: 768px) { - .button-container { - flex-direction: row; + .chat-header .button-container { + width: 100%; + max-width: 100%; + padding: 0 10px; + margin: 0; + box-sizing: border-box; gap: 8px; + flex-wrap: nowrap; + align-items: stretch; + } + + /* Hide vertical separators on mobile */ + .vertical-separator { + display: none; + } + + /* Make selectors and buttons fit screen width */ + #graph-select, + #custom-file-upload { + flex: 1; + min-width: 0; + padding: 8px 6px; + font-size: 13px; + text-overflow: ellipsis; + white-space: nowrap; + overflow: hidden; + height: 40px; + box-sizing: border-box; + } + + #graph-select { + max-width: 30%; + } + + #custom-file-upload { + max-width: 35%; + text-align: center; + display: flex; + align-items: center; + justify-content: center; + cursor: pointer; + } + + .dropdown-container { + flex: 1; + max-width: 35%; + } + + .custom-dropdown { + width: 100%; + height: 40px; + } + + .dropdown-selected { + padding: 8px 6px; + font-size: 13px; + height: 100%; + box-sizing: border-box; + display: flex; + align-items: center; + justify-content: space-between; + } + + .dropdown-text { + text-overflow: ellipsis; + white-space: nowrap; + overflow: hidden; + flex: 1; + } + + .dropdown-arrow { + margin-left: 4px; + flex-shrink: 0; } } -/* Select Elements Responsive */ -@media (max-width: 768px) { +@media (max-width: 480px) { + .chat-header .button-container { + padding: 0; + gap: 5px; + } + #graph-select, - #custom-file-upload, - #open-pg-modal { - min-width: 120px; - width: auto; - padding: 8px 10px; - font-size: 14px; - flex: 1; + #custom-file-upload { + padding: 6px 4px; + font-size: 12px; + height: 36px; + } + + .custom-dropdown { + height: 36px; + } + + .dropdown-selected { + padding: 6px 4px; + font-size: 12px; } } @@ -143,7 +223,7 @@ } .chat-input { - padding: 8px 12px; + padding: 0px; gap: 8px; /* Ensure it fits within viewport width */ margin: 0; From 3584188af617060b3d5bb8e35d5de5110c9164f9 Mon Sep 17 00:00:00 2001 From: Guy Korland Date: Sat, 23 Aug 2025 14:10:11 +0300 Subject: [PATCH 3/3] move to logger --- Pipfile | 1 + Pipfile.lock | 17 +- api/app_factory.py | 12 +- api/auth/oauth_handlers.py | 12 +- api/auth/user_management.py | 107 +++++---- api/graph.py | 24 +- api/loaders/mysql_loader.py | 33 +-- api/loaders/postgres_loader.py | 19 +- api/routes/auth.py | 43 ++-- api/routes/database.py | 13 +- api/routes/graphs.py | 401 ++++++++++++++++++--------------- 11 files changed, 373 insertions(+), 309 deletions(-) diff --git a/Pipfile b/Pipfile index 7c2c0c2e..7b7fc4e9 100644 --- a/Pipfile +++ b/Pipfile @@ -16,6 +16,7 @@ jsonschema = "~=4.25.0" tqdm = "~=4.67.1" python-multipart = "~=0.0.10" jinja2 = "~=3.1.4" +structlog = "~=25.4.0" [dev-packages] pytest = "~=8.4.1" diff --git a/Pipfile.lock b/Pipfile.lock index 09375fad..390da0f4 100644 --- a/Pipfile.lock +++ b/Pipfile.lock @@ -1,7 +1,7 @@ { "_meta": { "hash": { - "sha256": "4573e7a507f4e8c5c10787062953a7b4219cdf0ee3f0a3130c36c4abdd067f51" + "sha256": "da423ef9014a6e9d808fa16bf2198164862017f245951e30453d5f96feda6d33" }, "pipfile-spec": 6, "requires": { @@ -900,11 +900,11 @@ }, "openai": { "hashes": [ - "sha256:54d3457b2c8d7303a1bc002a058de46bdd8f37a8117751c7cf4ed4438051f151", - "sha256:787b4c3c8a65895182c58c424f790c25c790cc9a0330e34f73d55b6ee5a00e32" + "sha256:29f56df2236069686e64aca0e13c24a4ec310545afb25ef7da2ab1a18523f22d", + "sha256:6539a446cce154f8d9fb42757acdfd3ed9357ab0d34fcac11096c461da87133b" ], "markers": "python_version >= '3.8'", - "version": "==1.100.2" + "version": "==1.101.0" }, "packaging": { "hashes": [ @@ -1617,6 +1617,15 @@ "markers": "python_version >= '3.9'", "version": "==0.47.2" }, + "structlog": { + "hashes": [ + "sha256:186cd1b0a8ae762e29417095664adf1d6a31702160a46dacb7796ea82f7409e4", + "sha256:fe809ff5c27e557d14e613f45ca441aabda051d119ee5a0102aaba6ce40eed2c" + ], + "index": "pypi", + "markers": "python_version >= '3.8'", + "version": "==25.4.0" + }, "tiktoken": { "hashes": [ "sha256:10331d08b5ecf7a780b4fe4d0281328b23ab22cdb4ff65e68d56caeda9940ecc", diff --git a/api/app_factory.py b/api/app_factory.py index cb53445d..6454de08 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,6 +1,5 @@ """Application factory for the text2sql FastAPI app.""" -import logging import os import secrets @@ -11,13 +10,14 @@ from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.base import BaseHTTPMiddleware +from api.logging_config import get_logger from api.routes.auth import auth_router, init_auth from api.routes.graphs import graphs_router from api.routes.database import database_router # Load environment variables from .env file load_dotenv() -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = get_logger(__name__) class SecurityMiddleware(BaseHTTPMiddleware): @@ -55,7 +55,7 @@ def create_app(): secret_key = os.getenv("FASTAPI_SECRET_KEY") if not secret_key: secret_key = secrets.token_hex(32) - logging.warning("FASTAPI_SECRET_KEY not set, using generated key. Set this in production!") + logger.warning("FASTAPI_SECRET_KEY not set, using generated key. Set this in production!") # Add session middleware with explicit settings to ensure OAuth state persists app.add_middleware( @@ -87,8 +87,10 @@ def create_app(): async def handle_oauth_error(request: Request, exc: Exception): """Handle OAuth-related errors gracefully""" # Check if it's an OAuth-related error - if "token" in str(exc).lower() or "oauth" in str(exc).lower(): - logging.warning("OAuth error occurred: %s", exc) + exc_text = str(exc) + exc_text_lower = exc_text.lower() + if "token" in exc_text_lower or "oauth" in exc_text_lower: + logger.warning("OAuth error occurred", error=exc_text) request.session.clear() return RedirectResponse(url="/", status_code=302) diff --git a/api/auth/oauth_handlers.py b/api/auth/oauth_handlers.py index 4e58c5a9..d6b9e94b 100644 --- a/api/auth/oauth_handlers.py +++ b/api/auth/oauth_handlers.py @@ -4,14 +4,16 @@ callbacks can invoke them when processing OAuth responses. """ -import logging from typing import Dict, Any from fastapi import FastAPI, Request from authlib.integrations.starlette_client import OAuth +from api.logging_config import get_logger from .user_management import ensure_user_in_organizations +logger = get_logger(__name__) + def setup_oauth_handlers(app: FastAPI, oauth: OAuth): """Set up OAuth handlers for both Google and GitHub.""" @@ -30,7 +32,7 @@ async def handle_google_callback(_request: Request, # Validate required fields if not user_id or not email: - logging.error("Missing required fields from Google OAuth response") + logger.error("Missing required fields from Google OAuth response") return False # Check if identity exists in Organizations graph, create if new @@ -44,7 +46,7 @@ async def handle_google_callback(_request: Request, return True except Exception as exc: # capture exception for logging - logging.error("Error handling Google OAuth callback: %s", exc) + logger.error("Error handling Google OAuth callback", error=str(exc)) return False async def handle_github_callback(_request: Request, @@ -58,7 +60,7 @@ async def handle_github_callback(_request: Request, # Validate required fields if not user_id or not email: - logging.error("Missing required fields from GitHub OAuth response") + logger.error("Missing required fields from GitHub OAuth response") return False # Check if identity exists in Organizations graph, create if new @@ -72,7 +74,7 @@ async def handle_github_callback(_request: Request, return True except Exception as exc: # capture exception for logging - logging.error("Error handling GitHub OAuth callback: %s", exc) + logger.error("Error handling GitHub OAuth callback", error=str(exc)) return False # Store handlers in app state for use in route callbacks diff --git a/api/auth/user_management.py b/api/auth/user_management.py index 3e2ffc9a..6bd2c934 100644 --- a/api/auth/user_management.py +++ b/api/auth/user_management.py @@ -1,16 +1,16 @@ """User management and authentication functions for text2sql API.""" -import logging import time from functools import wraps from typing import Tuple, Optional, Dict, Any -import requests from fastapi import Request, HTTPException, status -from fastapi.responses import JSONResponse from authlib.integrations.starlette_client import OAuth from api.extensions import db +from api.logging_config import get_logger + +logger = get_logger(__name__) def ensure_user_in_organizations(provider_user_id, email, name, provider, picture=None): @@ -22,19 +22,24 @@ def ensure_user_in_organizations(provider_user_id, email, name, provider, pictur """ # Input validation if not provider_user_id or not email or not provider: - logging.error("Missing required parameters: provider_user_id=%s, email=%s, provider=%s", - provider_user_id, email, provider) + logger.error( + "Missing required parameters", + provider_user_id=provider_user_id, + email=email, + provider=provider, + ) + return False, None return False, None # Validate email format (basic check) if "@" not in email or "." not in email: - logging.error("Invalid email format: %s", email) + logger.error("Invalid email format", email=email) return False, None # Validate provider is in allowed list allowed_providers = ["google", "github"] if provider not in allowed_providers: - logging.error("Invalid provider: %s", provider) + logger.error("Invalid provider", provider=provider) return False, None try: @@ -99,28 +104,37 @@ def ensure_user_in_organizations(provider_user_id, email, name, provider, pictur # Determine the type of operation for logging if is_new_identity and not had_other_identities: # Brand new user (first identity) - logging.info("NEW USER CREATED: provider=%s, provider_user_id=%s, " - "email=%s, name=%s", provider, provider_user_id, email, name) + logger.info( + "NEW USER CREATED", + provider=provider, + provider_user_id=provider_user_id, + email=email, + name=name, + ) return True, {"identity": identity, "user": user} elif is_new_identity and had_other_identities: # New identity for existing user (cross-provider linking) - logging.info("NEW IDENTITY LINKED TO EXISTING USER: provider=%s, " - "provider_user_id=%s, email=%s, name=%s", - provider, provider_user_id, email, name) + logger.info( + "NEW IDENTITY LINKED TO EXISTING_USER", + provider=provider, + provider_user_id=provider_user_id, + email=email, + name=name, + ) return True, {"identity": identity, "user": user} else: # Existing identity login - logging.info("Existing identity found: provider=%s, email=%s", provider, email) + logger.info("Existing identity found", provider=provider, email=email) return False, {"identity": identity, "user": user} else: - logging.error("Failed to create/update identity and user: email=%s", email) + logger.error("Failed to create/update identity and user", email=email) return False, None except (AttributeError, ValueError, KeyError) as e: - logging.error("Error managing user in Organizations graph: %s", e) + logger.error("Error managing user in Organizations graph", error=str(e)) return False, None except Exception as e: - logging.error("Unexpected error managing user in Organizations graph: %s", e) + logger.error("Unexpected error managing user in Organizations graph", error=str(e)) return False, None @@ -128,14 +142,17 @@ def update_identity_last_login(provider, provider_user_id): """Update the last login timestamp for an existing identity""" # Input validation if not provider or not provider_user_id: - logging.error("Missing required parameters: provider=%s, provider_user_id=%s", - provider, provider_user_id) + logger.error( + "Missing required parameters", + provider=provider, + provider_user_id=provider_user_id, + ) return # Validate provider is in allowed list allowed_providers = ["google", "github"] if provider not in allowed_providers: - logging.error("Invalid provider: %s", provider) + logger.error("Invalid provider", provider=provider) return try: @@ -145,18 +162,32 @@ def update_identity_last_login(provider, provider_user_id): SET identity.last_login = timestamp() RETURN identity """ - organizations_graph.query(update_query, { - "provider": provider, - "provider_user_id": provider_user_id - }) - logging.info("Updated last login for identity: provider=%s, provider_user_id=%s", - provider, provider_user_id) + organizations_graph.query( + update_query, + { + "provider": provider, + "provider_user_id": provider_user_id, + }, + ) + logger.info( + "Updated last login for identity", + provider=provider, + provider_user_id=provider_user_id, + ) except (AttributeError, ValueError, KeyError) as e: - logging.error("Error updating last login for identity %s/%s: %s", - provider, provider_user_id, e) + logger.error( + "Error updating last login for identity", + provider=provider, + provider_user_id=provider_user_id, + error=str(e), + ) except Exception as e: - logging.error("Unexpected error updating last login for identity %s/%s: %s", - provider, provider_user_id, e) + logger.error( + "Unexpected error updating last login for identity", + provider=provider, + provider_user_id=provider_user_id, + error=str(e), + ) async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, Any]], bool]: @@ -192,9 +223,9 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, ) request.session["google_token"] = new_token resp = await oauth.google.get("/oauth2/v2/userinfo", token=new_token) - logging.info("Google access token refreshed successfully") + logger.info("Google access token refreshed successfully") except Exception as e: - logging.error("Google token refresh failed: %s", e) + logger.error("Google token refresh failed", error=str(e)) request.session.pop("google_token", None) request.session.pop("user_info", None) return None, False @@ -202,7 +233,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, if resp.status_code == 200: google_user = resp.json() if not google_user.get("id") or not google_user.get("email"): - logging.warning("Invalid Google user data received") + logger.warning("Invalid Google user data received") request.session.pop("google_token", None) request.session.pop("user_info", None) return None, False @@ -219,7 +250,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, request.session["token_validated_at"] = current_time return user_info, True except Exception as e: - logging.warning("Google OAuth validation error: %s", e) + logger.warning("Google OAuth validation error", error=str(e)) request.session.pop("google_token", None) request.session.pop("user_info", None) @@ -231,7 +262,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, if resp.status_code == 200: github_user = resp.json() if not github_user.get("id"): - logging.warning("Invalid GitHub user data received") + logger.warning("Invalid GitHub user data received") request.session.pop("github_token", None) request.session.pop("user_info", None) return None, False @@ -248,7 +279,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, email = email_resp.json()[0].get("email") if not email: - logging.warning("No email found for GitHub user") + logger.warning("No email found for GitHub user") request.session.pop("github_token", None) request.session.pop("user_info", None) return None, False @@ -264,7 +295,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, request.session["token_validated_at"] = current_time return user_info, True except Exception as e: - logging.warning("GitHub OAuth validation error: %s", e) + logger.warning("GitHub OAuth validation error", error=str(e)) request.session.pop("github_token", None) request.session.pop("user_info", None) @@ -273,7 +304,7 @@ async def validate_and_cache_user(request: Request) -> Tuple[Optional[Dict[str, return None, False except Exception as e: - logging.error("Unexpected error in validate_and_cache_user: %s", e) + logger.error("Unexpected error in validate_and_cache_user", error=str(e)) request.session.pop("user_info", None) return None, False @@ -312,7 +343,7 @@ async def wrapper(request: Request, *args, **kwargs): except HTTPException: raise except Exception as e: - logging.error("Unexpected error in token_required: %s", e) + logger.error("Unexpected error in token_required", error=str(e)) request.session.pop("user_info", None) raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, diff --git a/api/graph.py b/api/graph.py index bbbf810c..f26ffb7b 100644 --- a/api/graph.py +++ b/api/graph.py @@ -1,17 +1,16 @@ """Module to handle the graph data loading into the database.""" import json -import logging from itertools import combinations from typing import List, Tuple from litellm import completion from pydantic import BaseModel +from api.logging_config import get_logger from api.config import Config from api.extensions import db - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = get_logger(__name__) class TableDescription(BaseModel): @@ -61,10 +60,7 @@ def find(graph_id: str, queries_history: List[str], user_query = queries_history[-1] previous_queries = queries_history[:-1] - logging.info( - "Calling to an LLM to find relevant tables and columns for the query: %s", - user_query - ) + logger.info("LLM: find relevant tables/columns", query=user_query) # Call the completion model to get the relevant Cypher queries to retrieve # from the Graph that represent the Database schema. # The completion model will generate a set of Cypher query to retrieve the relevant nodes. @@ -94,16 +90,22 @@ def find(graph_id: str, queries_history: List[str], # Parse JSON string and convert to Pydantic model json_data = json.loads(json_str) descriptions = Descriptions(**json_data) - logging.info("Find tables based on: %s", descriptions.tables_descriptions) + logger.info( + "Find tables based on tables descriptions", + tables=descriptions.tables_descriptions, + ) tables_des = _find_tables(graph, descriptions.tables_descriptions) - logging.info("Find tables based on columns: %s", descriptions.columns_descriptions) + logger.info( + "Find tables based on columns descriptions", + columns=descriptions.columns_descriptions, + ) tables_by_columns_des = _find_tables_by_columns(graph, descriptions.columns_descriptions) # table names for sphere and route extraction base_tables_names = [table[0] for table in tables_des] - logging.info("Extracting tables by sphere") + logger.info("Extracting tables by sphere") tables_by_sphere = _find_tables_sphere(graph, base_tables_names) - logging.info("Extracting tables by connecting routes %s", base_tables_names) + logger.info("Extracting tables by connecting routes", tables=base_tables_names) tables_by_route, _ = find_connecting_tables(graph, base_tables_names) combined_tables = _get_unique_tables( tables_des + tables_by_columns_des + tables_by_route + tables_by_sphere diff --git a/api/loaders/mysql_loader.py b/api/loaders/mysql_loader.py index 24577d18..0c1531cb 100644 --- a/api/loaders/mysql_loader.py +++ b/api/loaders/mysql_loader.py @@ -2,7 +2,6 @@ import datetime import decimal -import logging import re from typing import Tuple, Dict, Any, List @@ -10,11 +9,10 @@ import pymysql from pymysql.cursors import DictCursor - +from api.logging_config import get_logger from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = get_logger(__name__) class MySQLLoader(BaseLoader): @@ -416,7 +414,10 @@ def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: Tuple of (success, message) """ try: - logging.info("Schema modification detected. Refreshing graph schema for: %s", graph_id) + logger.info( + "Schema modification detected. Refreshing graph schema for", + graph_id=graph_id, + ) # Import here to avoid circular imports from api.extensions import db @@ -439,17 +440,16 @@ def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: success, message = MySQLLoader.load(prefix, db_url) if success: - logging.info("Graph schema refreshed successfully.") + logger.info("Graph schema refreshed successfully.") return True, message - - logging.error("Schema refresh failed for graph %s: %s", graph_id, message) + logger.error("Schema refresh failed for graph", graph_id=graph_id, error=message) return False, "Failed to reload schema" except Exception as e: # Log the error and return failure - logging.error("Error refreshing graph schema: %s", str(e)) + logger.error("Error refreshing graph schema", error=str(e)) error_msg = "Error refreshing graph schema" - logging.error(error_msg) + logger.error(error_msg) return False, error_msg @staticmethod @@ -516,19 +516,6 @@ def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]: return result_list - except pymysql.MySQLError as e: - # Rollback in case of error - if 'conn' in locals(): - conn.rollback() - cursor.close() - conn.close() - except pymysql.MySQLError as e: - # Rollback in case of error - if 'conn' in locals(): - conn.rollback() - cursor.close() - conn.close() - raise Exception(f"MySQL query execution error: {str(e)}") from e except Exception as e: # Rollback in case of error if 'conn' in locals(): diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 2b41262a..cf5b136f 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -2,17 +2,16 @@ import datetime import decimal -import logging import re from typing import Tuple, Dict, Any, List import psycopg2 import tqdm +from api.logging_config import get_logger from api.loaders.base_loader import BaseLoader from api.loaders.graph_loader import load_to_graph - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = get_logger(__name__) class PostgresLoader(BaseLoader): @@ -377,7 +376,10 @@ def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: Tuple of (success, message) """ try: - logging.info("Schema modification detected. Refreshing graph schema for: %s", graph_id) + logger.info( + "Schema modification detected. Refreshing graph schema for", + graph_id=graph_id, + ) # Import here to avoid circular imports from api.extensions import db @@ -400,17 +402,16 @@ def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]: success, message = PostgresLoader.load(prefix, db_url) if success: - logging.info("Graph schema refreshed successfully.") + logger.info("Graph schema refreshed successfully.") return True, message - - logging.error("Schema refresh failed for graph %s: %s", graph_id, message) + logger.error("Schema refresh failed for graph", graph_id=graph_id, error=message) return False, "Failed to reload schema" except Exception as e: # Log the error and return failure - logging.error("Error refreshing graph schema: %s", str(e)) + logger.error("Error refreshing graph schema", error=str(e)) error_msg = "Error refreshing graph schema" - logging.error(error_msg) + logger.error(error_msg) return False, error_msg @staticmethod diff --git a/api/routes/auth.py b/api/routes/auth.py index 5ada36c3..6efa3d38 100644 --- a/api/routes/auth.py +++ b/api/routes/auth.py @@ -1,6 +1,5 @@ """Authentication routes for the text2sql API.""" -import logging import os import time from pathlib import Path @@ -13,12 +12,14 @@ from authlib.common.errors import AuthlibBaseError from starlette.config import Config +from api.logging_config import get_logger from api.auth.user_management import validate_and_cache_user # Router auth_router = APIRouter() TEMPLATES_DIR = str((Path(__file__).resolve().parents[1] / "../app/templates").resolve()) templates = Jinja2Templates(directory=TEMPLATES_DIR) +logger = get_logger(__name__) # ---- Helpers ---- def _get_provider_client(request: Request, provider: str): @@ -104,10 +105,9 @@ async def login_google(request: Request) -> RedirectResponse: # Helpful hint if localhost vs 127.0.0.1 mismatch is likely if not os.getenv("OAUTH_BASE_URL") and "127.0.0.1" in str(request.base_url): - logging.warning( - "OAUTH_BASE_URL not set and base URL is 127.0.0.1; " - "if your Google OAuth app uses 'http://localhost:5000', " - "set OAUTH_BASE_URL=http://localhost:5000 to avoid redirect_uri mismatch." + logger.warning( + "OAUTH_BASE_URL not set and base URL is 127.0.0.1; set OAUTH_BASE_URL to" + " avoid redirect mismatch." ) return await google.authorize_redirect(request, redirect_uri) @@ -122,13 +122,13 @@ async def google_authorized(request: Request) -> RedirectResponse: # Always fetch userinfo explicitly resp = await google.get("https://www.googleapis.com/oauth2/v2/userinfo", token=token) if resp.status_code != 200: - logging.error("Failed to fetch Google user info: %s", resp.text) + logger.error("Failed to fetch Google user info", body=resp.text) _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) user_info = resp.json() if not user_info.get("email"): - logging.error("Invalid Google user data received") + logger.error("Invalid Google user data received") _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) @@ -146,7 +146,7 @@ async def google_authorized(request: Request) -> RedirectResponse: return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) except AuthlibBaseError as e: - logging.error("Google OAuth error: %s", e) + logger.error("Google OAuth error", error=str(e)) _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) @@ -165,7 +165,7 @@ async def login_github(request: Request) -> RedirectResponse: # Helpful hint if localhost vs 127.0.0.1 mismatch is likely if not os.getenv("OAUTH_BASE_URL") and "127.0.0.1" in str(request.base_url): - logging.warning( + logger.warning( "OAUTH_BASE_URL not set and base URL is 127.0.0.1; " "if your GitHub OAuth app uses 'http://localhost:5000', " "set OAUTH_BASE_URL=http://localhost:5000 to avoid redirect_uri mismatch." @@ -183,12 +183,11 @@ async def github_authorized(request: Request) -> RedirectResponse: # Fetch GitHub user info resp = await github.get("https://api.github.com/user", token=token) if resp.status_code != 200: - logging.error("Failed to fetch GitHub user info: %s", resp.text) + logger.error("Failed to fetch GitHub user info", body=resp.text) _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) user_info = resp.json() - # Get user email if not public email = user_info.get("email") if not email: @@ -202,7 +201,7 @@ async def github_authorized(request: Request) -> RedirectResponse: break if not user_info.get("id") or not email: - logging.error("Invalid GitHub user data received") + logger.error("Invalid GitHub user data received") _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) @@ -220,7 +219,7 @@ async def github_authorized(request: Request) -> RedirectResponse: return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) except AuthlibBaseError as e: - logging.error("GitHub OAuth error: %s", e) + logger.error("GitHub OAuth error", error=str(e)) _clear_auth_session(request.session) return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND) @@ -256,19 +255,19 @@ async def logout(request: Request) -> RedirectResponse: headers={"content-type": "application/x-www-form-urlencoded"}, ) if resp.status_code != 200: - logging.warning( - "Google token revoke failed (%s): %s", - resp.status_code, - resp.text, + logger.warning( + "Google token revoke failed", + status=resp.status_code, + body=resp.text, ) else: - logging.info("Successfully revoked Google token") + logger.info("Successfully revoked Google token") except Exception as e: - logging.error("Error revoking Google tokens: %s", e) + logger.error("Error revoking Google tokens", error=str(e)) # ---- Handle GitHub tokens ---- if github_token: - logging.info("GitHub token found, clearing from session (no remote revoke available).") + logger.info("GitHub token found, clearing from session (no remote revoke available).") # GitHub logout is local only unless we call the App management API # ---- Clear session auth keys ---- @@ -287,7 +286,7 @@ def init_auth(app): google_client_id = os.getenv("GOOGLE_CLIENT_ID") google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET") if not google_client_id or not google_client_secret: - logging.warning("Google OAuth env vars not set; login will fail until configured.") + logger.warning("Google OAuth env vars not set; login will fail until configured.") oauth.register( name="google", @@ -300,7 +299,7 @@ def init_auth(app): github_client_id = os.getenv("GITHUB_CLIENT_ID") github_client_secret = os.getenv("GITHUB_CLIENT_SECRET") if not github_client_id or not github_client_secret: - logging.warning("GitHub OAuth env vars not set; login will fail until configured.") + logger.warning("GitHub OAuth env vars not set; login will fail until configured.") oauth.register( name="github", diff --git a/api/routes/database.py b/api/routes/database.py index 4254d45f..a9e0b821 100644 --- a/api/routes/database.py +++ b/api/routes/database.py @@ -1,15 +1,16 @@ """Database connection routes for the text2sql API.""" -import logging from fastapi import APIRouter, Request, HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel -from api.auth.user_management import token_required from api.loaders.postgres_loader import PostgresLoader from api.loaders.mysql_loader import MySQLLoader +from api.auth.user_management import token_required +from api.logging_config import get_logger database_router = APIRouter() +logger = get_logger(__name__) class DatabaseConnectionRequest(BaseModel): @@ -47,7 +48,7 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque # Attempt to connect/load using the PostgreSQL loader success, result = PostgresLoader.load(request.state.user_id, url) except (ValueError, ConnectionError) as e: - logging.error("PostgreSQL connection error: %s", str(e)) + logger.error("PostgreSQL connection error", error=str(e)) raise HTTPException( status_code=500, detail="Failed to connect to PostgreSQL database", @@ -59,7 +60,7 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque # Attempt to connect/load using the MySQL loader success, result = MySQLLoader.load(request.state.user_id, url) except (ValueError, ConnectionError) as e: - logging.error("MySQL connection error: %s", str(e)) + logger.error("MySQL connection error", error=str(e)) raise HTTPException( status_code=500, detail="Failed to connect to MySQL database" ) @@ -80,9 +81,9 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque }) # Don't return detailed error messages to prevent information exposure - logging.error("Database loader failed: %s", result) + logger.error("Database loader failed", error=str(result)) raise HTTPException(status_code=400, detail="Failed to load database schema") except (ValueError, TypeError) as e: - logging.error("Unexpected error in database connection: %s", str(e)) + logger.error("Unexpected error in database connection", error=str(e)) raise HTTPException(status_code=500, detail="Internal server error") diff --git a/api/routes/graphs.py b/api/routes/graphs.py index 6fbab7fb..67ccc875 100644 --- a/api/routes/graphs.py +++ b/api/routes/graphs.py @@ -1,7 +1,6 @@ """Graph-related routes for the text2sql API.""" import json -import logging import time from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError @@ -20,10 +19,13 @@ from api.loaders.mysql_loader import MySQLLoader from api.loaders.odata_loader import ODataLoader +from api.logging_config import get_logger + # Use the same delimiter as in the JavaScript MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" graphs_router = APIRouter() +logger = get_logger(__name__) class GraphData(BaseModel): @@ -122,7 +124,11 @@ async def get_graph_data(request: Request, graph_id: str): try: graph = db.select_graph(namespaced) except Exception as e: - logging.error("Failed to select graph %s: %s", sanitize_log_input(namespaced), e) + logger.error( + "Failed to select graph", + graph=sanitize_log_input(namespaced), + error=str(e), + ) return JSONResponse(content={"error": "Graph not found or database error"}, status_code=404) # Build table nodes with columns and table-to-table links (foreign keys) @@ -143,7 +149,11 @@ async def get_graph_data(request: Request, graph_id: str): tables_res = graph.query(tables_query).result_set links_res = graph.query(links_query).result_set except Exception as e: - logging.error("Error querying graph data for %s: %s", sanitize_log_input(namespaced), e) + logger.error( + "Error querying graph data", + graph=sanitize_log_input(namespaced), + error=str(e), + ) return JSONResponse(content={"error": "Failed to read graph data"}, status_code=500) nodes = [] @@ -260,7 +270,7 @@ async def load_graph(request: Request, data: GraphData = None, file: UploadFile return JSONResponse(content={"message": "Graph loaded successfully", "graph_id": graph_id}) # Log detailed error but return generic message to user - logging.error("Graph loading failed: %s", str(result)[:100]) + logger.error("Graph loading failed", error=str(result)[:100]) raise HTTPException(status_code=400, detail="Failed to load graph data") @@ -291,7 +301,7 @@ async def query_graph(request: Request, graph_id: str, chat_data: ChatRequest): if len(queries_history) == 0: raise HTTPException(status_code=400, detail="Empty chat history") - logging.info("User Query: %s", sanitize_query(queries_history[-1])) + logger.info("User Query", query=sanitize_query(queries_history[-1])) # Create a generator function for streaming async def generate(): @@ -299,9 +309,12 @@ async def generate(): agent_an = AnalysisAgent(queries_history, result_history) step1_start = time.perf_counter() - step = {"type": "reasoning_step", - "message": "Step 1: Analyzing user query and generating SQL..."} + step = { + "type": "reasoning_step", + "message": "Step 1: Analyzing user query and generating SQL...", + } yield json.dumps(step) + MESSAGE_DELIMITER + # Ensure the database description is loaded db_description, db_url = get_db_description(graph_id) @@ -309,93 +322,104 @@ async def generate(): db_type, loader_class = get_database_type_and_loader(db_url) if not loader_class: - yield json.dumps({ - "type": "error", - "message": "Unable to determine database type" - }) + MESSAGE_DELIMITER + yield json.dumps( + {"type": "error", "message": "Unable to determine database type"} + ) + MESSAGE_DELIMITER return - logging.info("Calling to relevancy agent with query: %s", - sanitize_query(queries_history[-1])) + logger.info( + "Calling to relevancy agent", + query=sanitize_log_input(sanitize_query(queries_history[-1])), + ) rel_start = time.perf_counter() answer_rel = agent_rel.get_answer(queries_history[-1], db_description) rel_elapsed = time.perf_counter() - rel_start - logging.info("Relevancy check took %.2f seconds", rel_elapsed) + logger.info("Relevancy check", seconds=rel_elapsed) + if answer_rel["status"] != "On-topic": step = { "type": "followup_questions", "message": "Off topic question: " + answer_rel["reason"], } - logging.info("SQL Fail reason: %s", answer_rel["reason"]) + logger.info("SQL Fail reason", reason=answer_rel.get("reason")) yield json.dumps(step) + MESSAGE_DELIMITER # Total time for the pre-analysis phase step1_elapsed = time.perf_counter() - step1_start - logging.info("Step 1 (relevancy + prep) took %.2f seconds", step1_elapsed) - else: - # Use a thread pool to enforce timeout - with ThreadPoolExecutor(max_workers=1) as executor: - find_start = time.perf_counter() - future = executor.submit(find, graph_id, queries_history, db_description) - try: - _, result, _ = future.result(timeout=120) - find_elapsed = time.perf_counter() - find_start - logging.info("Finding relevant tables took %.2f seconds", find_elapsed) - # Total time for the pre-analysis phase - step1_elapsed = time.perf_counter() - step1_start - logging.info( - "Step 1 (relevancy + table finding) took %.2f seconds", - step1_elapsed, - ) - except FuturesTimeoutError: - yield json.dumps( - { - "type": "error", - "message": ("Timeout error while finding tables relevant to " - "your request."), - } - ) + MESSAGE_DELIMITER - return - except Exception as e: - logging.info("Error in find function: %s", e) - yield json.dumps( - {"type": "error", "message": "Error in find function"} - ) + MESSAGE_DELIMITER - return - - logging.info("Calling to analysis agent with query: %s", - sanitize_query(queries_history[-1])) - - analysis_start = time.perf_counter() - answer_an = agent_an.get_analysis( - queries_history[-1], result, db_description, instructions - ) - analysis_elapsed = time.perf_counter() - analysis_start - logging.info("SQL generation took %.2f seconds", analysis_elapsed) - - logging.info("SQL Result: %s", answer_an['sql_query']) - yield json.dumps( - { - "type": "final_result", - "data": answer_an["sql_query"], - "conf": answer_an["confidence"], - "miss": answer_an["missing_information"], - "amb": answer_an["ambiguities"], - "exp": answer_an["explanation"], - "is_valid": answer_an["is_sql_translatable"], - } - ) + MESSAGE_DELIMITER - - # If the SQL query is valid, execute it using the postgress database db_url - if answer_an["is_sql_translatable"]: - # Check if this is a destructive operation that requires confirmation - sql_query = answer_an["sql_query"] - sql_type = sql_query.strip().split()[0].upper() if sql_query else "" + logger.info("Step 1 (relevancy + prep) took", seconds=step1_elapsed) + return - destructive_ops = ['INSERT', 'UPDATE', 'DELETE', 'DROP', - 'CREATE', 'ALTER', 'TRUNCATE'] - if sql_type in destructive_ops: - # This is a destructive operation - ask for user confirmation - confirmation_message = f"""⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ + # Use a thread pool to enforce timeout + with ThreadPoolExecutor(max_workers=1) as executor: + find_start = time.perf_counter() + future = executor.submit(find, graph_id, queries_history, db_description) + try: + _, result, _ = future.result(timeout=120) + find_elapsed = time.perf_counter() - find_start + logger.info("Finding relevant tables took", seconds=find_elapsed) + # Total time for the pre-analysis phase + step1_elapsed = time.perf_counter() - step1_start + logger.info("Step 1 (relevancy + table finding) took", seconds=step1_elapsed) + except FuturesTimeoutError: + yield json.dumps( + { + "type": "error", + "message": "Timeout finding relevant tables", + } + ) + MESSAGE_DELIMITER + return + except Exception as e: + logger.info("Error in find function", error=str(e)) + yield json.dumps( + {"type": "error", "message": "Error in find function"} + ) + MESSAGE_DELIMITER + return + + logger.info( + "Calling to analysis agent", + query=sanitize_query(queries_history[-1]), + ) + analysis_start = time.perf_counter() + answer_an = agent_an.get_analysis( + queries_history[-1], result, db_description, instructions + ) + analysis_elapsed = time.perf_counter() - analysis_start + logger.info("SQL generation took", seconds=analysis_elapsed) + + logger.info( + "SQL Result", + sql=answer_an.get('sql_query'), + ) + yield json.dumps( + { + "type": "final_result", + "data": answer_an["sql_query"], + "conf": answer_an["confidence"], + "miss": answer_an["missing_information"], + "amb": answer_an["ambiguities"], + "exp": answer_an["explanation"], + "is_valid": answer_an["is_sql_translatable"], + } + ) + MESSAGE_DELIMITER + + # If the SQL query is valid, execute it using the postgress database db_url + if answer_an["is_sql_translatable"]: + # Check if this is a destructive operation that requires confirmation + sql_query = answer_an["sql_query"] + sql_type = sql_query.strip().split()[0].upper() if sql_query else "" + + destructive_ops = [ + 'INSERT', + 'UPDATE', + 'DELETE', + 'DROP', + 'CREATE', + 'ALTER', + 'TRUNCATE', + ] + if sql_type in destructive_ops: + # This is a destructive operation - ask for user confirmation + confirmation_message = ( + f"""⚠️ DESTRUCTIVE OPERATION DETECTED ⚠️ The generated SQL query will perform a **{sql_type}** operation: @@ -404,119 +428,124 @@ async def generate(): What this will do: """ - if sql_type == 'INSERT': - confirmation_message += "• Add new data to the database" - elif sql_type == 'UPDATE': - confirmation_message += ("• Modify existing data in the " - "database") - elif sql_type == 'DELETE': - confirmation_message += ("• **PERMANENTLY DELETE** data " - "from the database") - elif sql_type == 'DROP': - confirmation_message += ("• **PERMANENTLY DELETE** entire " - "tables or database objects") - elif sql_type == 'CREATE': - confirmation_message += ("• Create new tables or database " - "objects") - elif sql_type == 'ALTER': - confirmation_message += ("• Modify the structure of existing " - "tables") - elif sql_type == 'TRUNCATE': - confirmation_message += ("• **PERMANENTLY DELETE ALL DATA** " - "from specified tables") - confirmation_message += """ + ) + if sql_type == 'INSERT': + confirmation_message += "• Add new data to the database" + elif sql_type == 'UPDATE': + confirmation_message += ("• Modify existing data in the " "database") + # update handled above + elif sql_type == 'DELETE': + confirmation_message += "• **PERMANENTLY DELETE** data from the database" + elif sql_type == 'DROP': + confirmation_message += ( + "\u2022 **PERMANENTLY DELETE** entire " "tables or database objects" + ) + elif sql_type == 'CREATE': + confirmation_message += "• Create new tables or database objects" + elif sql_type == 'ALTER': + confirmation_message += "• Modify the structure of existing tables" + elif sql_type == 'TRUNCATE': + confirmation_message += ( + "\u2022 **PERMANENTLY DELETE ALL DATA** " "from specified tables" + ) + confirmation_message += """ ⚠️ WARNING: This operation will make changes to your database and may be irreversible. """ - yield json.dumps( - { - "type": "destructive_confirmation", - "message": confirmation_message, - "sql_query": sql_query, - "operation_type": sql_type - } - ) + MESSAGE_DELIMITER - return # Stop here and wait for user confirmation - - try: - step = {"type": "reasoning_step", "message": "Step 2: Executing SQL query"} - yield json.dumps(step) + MESSAGE_DELIMITER + yield json.dumps( + { + "type": "destructive_confirmation", + "message": confirmation_message, + "sql_query": sql_query, + "operation_type": sql_type, + } + ) + MESSAGE_DELIMITER + return # Stop here and wait for user confirmation - # Check if this query modifies the database schema using the appropriate loader - is_schema_modifying, operation_type = ( - loader_class.is_schema_modifying_query(sql_query) - ) + try: + step = { + "type": "reasoning_step", + "message": "Step 2: Executing SQL query", + } + yield json.dumps(step) + MESSAGE_DELIMITER - query_results = loader_class.execute_sql_query(answer_an["sql_query"], db_url) - yield json.dumps( - { - "type": "query_result", - "data": query_results, - } - ) + MESSAGE_DELIMITER - - # If schema was modified, refresh the graph using the appropriate loader - if is_schema_modifying: - step = {"type": "reasoning_step", - "message": ("Step 3: Schema change detected - " - "refreshing graph...")} - yield json.dumps(step) + MESSAGE_DELIMITER - - refresh_result = loader_class.refresh_graph_schema( - graph_id, db_url) - refresh_success, refresh_message = refresh_result - - if refresh_success: - refresh_msg = (f"✅ Schema change detected " - f"({operation_type} operation)\n\n" - f"🔄 Graph schema has been automatically " - f"refreshed with the latest database " - f"structure.") - yield json.dumps( - { - "type": "schema_refresh", - "message": refresh_msg, - "refresh_status": "success" - } - ) + MESSAGE_DELIMITER - else: - failure_msg = (f"⚠️ Schema was modified but graph " - f"refresh failed: {refresh_message}") - yield json.dumps( - { - "type": "schema_refresh", - "message": failure_msg, - "refresh_status": "failed" - } - ) + MESSAGE_DELIMITER - - # Generate user-readable response using AI - step_num = "4" if is_schema_modifying else "3" - step = {"type": "reasoning_step", - "message": f"Step {step_num}: Generating user-friendly response"} + # Check if this query modifies the database schema using the appropriate loader + is_schema_modifying, operation_type = loader_class.is_schema_modifying_query( + sql_query + ) + + query_results = loader_class.execute_sql_query(answer_an["sql_query"], db_url) + yield json.dumps( + { + "type": "query_result", + "data": query_results, + } + ) + MESSAGE_DELIMITER + # If schema was modified, refresh the graph using the appropriate loader + if is_schema_modifying: + step = { + "type": "reasoning_step", + "message": "Step 3: Schema change detected - refreshing graph...", + } yield json.dumps(step) + MESSAGE_DELIMITER - response_agent = ResponseFormatterAgent() - user_readable_response = response_agent.format_response( - user_query=queries_history[-1], - sql_query=answer_an["sql_query"], - query_results=query_results, - db_description=db_description - ) + refresh_result = loader_class.refresh_graph_schema(graph_id, db_url) + refresh_success, refresh_message = refresh_result - yield json.dumps( - { - "type": "ai_response", - "message": user_readable_response, - } - ) + MESSAGE_DELIMITER + if refresh_success: + refresh_msg = ( + f"✅ Schema change detected ({operation_type} operation)\n\n" + "🔄 Graph schema has been automatically refreshed " + "with the latest database structure." + ) + yield json.dumps( + { + "type": "schema_refresh", + "message": refresh_msg, + "refresh_status": "success", + } + ) + MESSAGE_DELIMITER + else: + failure_msg = ( + f"⚠️ Schema was modified but graph refresh failed: {refresh_message}" + ) + yield json.dumps( + { + "type": "schema_refresh", + "message": failure_msg, + "refresh_status": "failed", + } + ) + MESSAGE_DELIMITER - except Exception as e: - logging.error("Error executing SQL query: %s", str(e)) - yield json.dumps( - {"type": "error", "message": "Error executing SQL query"} - ) + MESSAGE_DELIMITER + # Generate user-readable response using AI + step_num = "4" if is_schema_modifying else "3" + step = { + "type": "reasoning_step", + "message": f"Step {step_num}: Generating user-friendly response", + } + yield json.dumps(step) + MESSAGE_DELIMITER + + response_agent = ResponseFormatterAgent() + user_readable_response = response_agent.format_response( + user_query=queries_history[-1], + sql_query=answer_an["sql_query"], + query_results=query_results, + db_description=db_description, + ) + + yield json.dumps( + { + "type": "ai_response", + "message": user_readable_response, + } + ) + MESSAGE_DELIMITER + + except Exception as e: + logger.error("Error executing SQL query", error=str(e)) + yield json.dumps( + {"type": "error", "message": "Error executing SQL query"} + ) + MESSAGE_DELIMITER return StreamingResponse(generate(), media_type="application/json") @@ -629,7 +658,7 @@ def generate_confirmation(): ) + MESSAGE_DELIMITER except Exception as e: - logging.error("Error executing confirmed SQL query: %s", str(e)) + logger.error("Error executing confirmed SQL query", error=str(e)) yield json.dumps( {"type": "error", "message": "Error executing query"} ) + MESSAGE_DELIMITER @@ -683,14 +712,14 @@ async def refresh_graph_schema(request: Request, graph_id: str): "message": f"Graph schema refreshed successfully using {db_type}" }) - logging.error("Schema refresh failed for graph %s: %s", graph_id, message) + logger.error("Schema refresh failed for graph", graph_id=graph_id, error=message) return JSONResponse({ "success": False, "error": "Failed to refresh schema" }, status_code=500) except Exception as e: - logging.error("Error in manual schema refresh: %s", e) + logger.error("Error in manual schema refresh", error=str(e)) return JSONResponse({ "success": False, "error": "Error refreshing schema"