diff --git a/api/loaders/base_loader.py b/api/loaders/base_loader.py index 735a9a57..1dab4b1c 100644 --- a/api/loaders/base_loader.py +++ b/api/loaders/base_loader.py @@ -1,16 +1,16 @@ """Base loader module providing abstract base class for data loaders.""" from abc import ABC -from typing import Tuple +from typing import AsyncGenerator class BaseLoader(ABC): """Abstract base class for data loaders.""" @staticmethod - async def load(_graph_id: str, _data) -> Tuple[bool, str]: + async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]: """ Load the graph data into the database. This method must be implemented by any subclass. """ - return False, "Not implemented" + yield False, "Not implemented" diff --git a/api/loaders/mysql_loader.py b/api/loaders/mysql_loader.py index 5c22c03a..d2ef3091 100644 --- a/api/loaders/mysql_loader.py +++ b/api/loaders/mysql_loader.py @@ -4,7 +4,7 @@ import decimal import logging import re -from typing import Tuple, Dict, Any, List +from typing import AsyncGenerator, Tuple, Dict, Any, List import tqdm import pymysql @@ -125,7 +125,7 @@ def _parse_mysql_url(connection_url: str) -> Dict[str, str]: } @staticmethod - async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: + async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]: """ Load the graph data from a MySQL database into the graph database. @@ -148,9 +148,11 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: db_name = conn_params['database'] # Get all table information + yield True, "Extracting table information..." entities = MySQLLoader.extract_tables_info(cursor, db_name) # Get all relationship information + yield True, "Extracting relationship information..." relationships = MySQLLoader.extract_relationships(cursor, db_name) # Close database connection @@ -158,16 +160,17 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: conn.close() # Load data into graph + yield True, "Loading data into graph..." await load_to_graph(f"{prefix}_{db_name}", entities, relationships, db_name=db_name, db_url=connection_url) - return True, (f"MySQL schema loaded successfully. " + yield True, (f"MySQL schema loaded successfully. " f"Found {len(entities)} tables.") except pymysql.MySQLError as e: - return False, f"MySQL connection error: {str(e)}" + yield False, f"MySQL connection error: {str(e)}" except Exception as e: - return False, f"Error loading MySQL schema: {str(e)}" + yield False, f"Error loading MySQL schema: {str(e)}" @staticmethod def extract_tables_info(cursor, db_name: str) -> Dict[str, Any]: diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py index 0ac50e49..fa2da1b6 100644 --- a/api/loaders/postgres_loader.py +++ b/api/loaders/postgres_loader.py @@ -4,7 +4,7 @@ import decimal import logging import re -from typing import Tuple, Dict, Any, List +from typing import AsyncGenerator, Tuple, Dict, Any, List import psycopg2 import tqdm @@ -64,7 +64,7 @@ def _serialize_value(value): return value @staticmethod - async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: + async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]: """ Load the graph data from a PostgreSQL database into the graph database. @@ -86,8 +86,10 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: db_name = db_name.split('?')[0] # Get all table information + yield True, "Extracting table information..." entities = PostgresLoader.extract_tables_info(cursor) + yield True, "Extracting relationship information..." # Get all relationship information relationships = PostgresLoader.extract_relationships(cursor) @@ -95,17 +97,18 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]: cursor.close() conn.close() + yield True, "Loading data into graph..." # Load data into graph await load_to_graph(f"{prefix}_{db_name}", entities, relationships, db_name=db_name, db_url=connection_url) - return True, (f"PostgreSQL schema loaded successfully. " + yield True, (f"PostgreSQL schema loaded successfully. " f"Found {len(entities)} tables.") except psycopg2.Error as e: - return False, f"PostgreSQL connection error: {str(e)}" + yield False, f"PostgreSQL connection error: {str(e)}" except Exception as e: - return False, f"Error loading PostgreSQL schema: {str(e)}" + yield False, f"Error loading PostgreSQL schema: {str(e)}" @staticmethod def extract_tables_info(cursor) -> Dict[str, Any]: diff --git a/api/routes/database.py b/api/routes/database.py index d6a17547..3d783178 100644 --- a/api/routes/database.py +++ b/api/routes/database.py @@ -1,8 +1,11 @@ """Database connection routes for the text2sql API.""" + import logging +import json +import time from fastapi import APIRouter, Request, HTTPException -from fastapi.responses import JSONResponse +from fastapi.responses import StreamingResponse from pydantic import BaseModel from api.auth.user_management import token_required @@ -11,6 +14,8 @@ database_router = APIRouter() +# Use the same delimiter as in the JavaScript frontend for streaming chunks +MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" class DatabaseConnectionRequest(BaseModel): """Database connection request model. @@ -18,8 +23,8 @@ class DatabaseConnectionRequest(BaseModel): Args: BaseModel (_type_): _description_ """ - url: str + url: str @database_router.post("/database", operation_id="connect_database") @token_required @@ -27,7 +32,7 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque """ Accepts a JSON payload with a database URL and attempts to connect. Supports both PostgreSQL and MySQL databases. - Returns success or error message. + Streams progress steps as a sequence of JSON messages separated by MESSAGE_DELIMITER. """ url = db_request.url if not url: @@ -37,52 +42,94 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque if not isinstance(url, str) or len(url.strip()) == 0: raise HTTPException(status_code=400, detail="Invalid URL format") - try: - success = False - result = "" - - # Check for PostgreSQL URL - if url.startswith("postgres://") or url.startswith("postgresql://"): + async def generate(): + overall_start = time.perf_counter() + steps_counter = 0 + try: + # Step 1: Start + steps_counter += 1 + yield json.dumps( + { + "type": "reasoning_step", + "message": f"Step {steps_counter}: Starting database connection", + } + ) + MESSAGE_DELIMITER + + # Step 2: Determine type + db_type = None + if url.startswith("postgres://") or url.startswith("postgresql://"): + db_type = "postgresql" + loader = PostgresLoader + elif url.startswith("mysql://"): + db_type = "mysql" + loader = MySQLLoader + else: + yield json.dumps( + {"type": "error", "message": "Invalid database URL format"} + ) + MESSAGE_DELIMITER + return + + steps_counter += 1 + yield json.dumps( + { + "type": "reasoning_step", + "message": f"Step {steps_counter}: Detected database type: {db_type}. " + "Attempting to load schema...", + } + ) + MESSAGE_DELIMITER + + # Step 3: Attempt to load schema using the loader + success, result = [False, ""] try: - # Attempt to connect/load using the PostgreSQL loader - success, result = await PostgresLoader.load(request.state.user_id, url) - except (ValueError, ConnectionError) as e: - logging.error("PostgreSQL connection error: %s", str(e)) - raise HTTPException( - status_code=500, - detail="Failed to connect to PostgreSQL database", + load_start = time.perf_counter() + async for progress in loader.load(request.state.user_id, url): + success, result = progress + if success: + steps_counter += 1 + yield json.dumps( + { + "type": "reasoning_step", + "message": f"Step {steps_counter}: {result}", + } + ) + MESSAGE_DELIMITER + else: + break + + load_elapsed = time.perf_counter() - load_start + logging.info( + "Database load attempt finished in %.2f seconds", load_elapsed ) - # Check for MySQL URL - elif url.startswith("mysql://"): - try: - # Attempt to connect/load using the MySQL loader - success, result = await MySQLLoader.load(request.state.user_id, url) - except (ValueError, ConnectionError) as e: - logging.error("MySQL connection error: %s", str(e)) - raise HTTPException( - status_code=500, detail="Failed to connect to MySQL database" - ) - - else: - raise HTTPException( - status_code=400, - detail=( - "Invalid database URL. Supported formats: postgresql:// " - "or mysql://" - ), + if success: + yield json.dumps( + { + "type": "final_result", + "success": True, + "message": "Database connected and schema loaded successfully", + } + ) + MESSAGE_DELIMITER + else: + # Don't stream the full internal result; give higher-level error + logging.error("Database loader failed: %s", str(result)) + yield json.dumps( + {"type": "error", "message": "Failed to load database schema"} + ) + MESSAGE_DELIMITER + except Exception as e: + logging.exception("Error while loading database schema: %s", str(e)) + yield json.dumps( + {"type": "error", "message": "Error connecting to database"} + ) + MESSAGE_DELIMITER + + except Exception as e: + logging.exception("Unexpected error in connect_database stream: %s", str(e)) + yield json.dumps( + {"type": "error", "message": "Internal server error"} + ) + MESSAGE_DELIMITER + finally: + overall_elapsed = time.perf_counter() - overall_start + logging.info( + "connect_database processing completed - Total time: %.2f seconds", + overall_elapsed, ) - if success: - return JSONResponse(content={ - "success": True, - "message": "Database connected successfully" - }) - - # Don't return detailed error messages to prevent information exposure - logging.error("Database loader failed: %s", result) - raise HTTPException(status_code=400, detail="Failed to load database schema") - - except (ValueError, TypeError) as e: - logging.error("Unexpected error in database connection: %s", str(e)) - raise HTTPException(status_code=500, detail="Internal server error") + return StreamingResponse(generate(), media_type="application/json") diff --git a/app/public/css/modals.css b/app/public/css/modals.css index ad438fb1..7abcd2e4 100644 --- a/app/public/css/modals.css +++ b/app/public/css/modals.css @@ -399,6 +399,43 @@ overflow-x: auto; } +/* Styles for incremental database connection steps shown in the connect modal */ +#db-connection-steps { + margin: 16px 24px; + font-size: 14px; + color: var(--text-primary); +} + +#db-connection-steps-list { + list-style: none; + padding: 0; + margin: 0; + max-height: 220px; + overflow: auto; +} + +.db-connection-step { + display: flex; + align-items: center; + margin: 8px 0; +} + +.db-connection-step .step-icon { + display: inline-flex; + width: 20px; + height: 20px; + margin-right: 8px; + border-radius: 50%; + align-items: center; + justify-content: center; + font-size: 12px; + line-height: 20px; +} + +.db-connection-step .step-icon.pending { color: #1f6feb; } +.db-connection-step .step-icon.success { color: #16a34a; } +.db-connection-step .step-icon.error { color: #dc2626; } + .alert { padding: 1.5em; border-radius: 6px; diff --git a/app/templates/components/database_modal.j2 b/app/templates/components/database_modal.j2 index 7795401a..c2580112 100644 --- a/app/templates/components/database_modal.j2 +++ b/app/templates/components/database_modal.j2 @@ -3,6 +3,12 @@

Connect to Database

+ + +
+ +
+