Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions api/loaders/base_loader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Base loader module providing abstract base class for data loaders."""

from abc import ABC
from typing import Tuple
from typing import AsyncGenerator


class BaseLoader(ABC):
"""Abstract base class for data loaders."""

@staticmethod
async def load(_graph_id: str, _data) -> Tuple[bool, str]:
async def load(_graph_id: str, _data) -> AsyncGenerator[tuple[bool, str], None]:
"""
Load the graph data into the database.
This method must be implemented by any subclass.
"""
return False, "Not implemented"
yield False, "Not implemented"
13 changes: 8 additions & 5 deletions api/loaders/mysql_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import decimal
import logging
import re
from typing import Tuple, Dict, Any, List
from typing import AsyncGenerator, Tuple, Dict, Any, List

import tqdm
import pymysql
Expand Down Expand Up @@ -125,7 +125,7 @@ def _parse_mysql_url(connection_url: str) -> Dict[str, str]:
}

@staticmethod
async def load(prefix: str, connection_url: str) -> Tuple[bool, str]:
async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]:
"""
Load the graph data from a MySQL database into the graph database.

Expand All @@ -148,26 +148,29 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]:
db_name = conn_params['database']

# Get all table information
yield True, "Extracting table information..."
entities = MySQLLoader.extract_tables_info(cursor, db_name)

# Get all relationship information
yield True, "Extracting relationship information..."
relationships = MySQLLoader.extract_relationships(cursor, db_name)

# Close database connection
cursor.close()
conn.close()

# Load data into graph
yield True, "Loading data into graph..."
await load_to_graph(f"{prefix}_{db_name}", entities, relationships,
db_name=db_name, db_url=connection_url)

return True, (f"MySQL schema loaded successfully. "
yield True, (f"MySQL schema loaded successfully. "
f"Found {len(entities)} tables.")

except pymysql.MySQLError as e:
return False, f"MySQL connection error: {str(e)}"
yield False, f"MySQL connection error: {str(e)}"
except Exception as e:
return False, f"Error loading MySQL schema: {str(e)}"
yield False, f"Error loading MySQL schema: {str(e)}"

@staticmethod
def extract_tables_info(cursor, db_name: str) -> Dict[str, Any]:
Expand Down
13 changes: 8 additions & 5 deletions api/loaders/postgres_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import decimal
import logging
import re
from typing import Tuple, Dict, Any, List
from typing import AsyncGenerator, Tuple, Dict, Any, List

import psycopg2
import tqdm
Expand Down Expand Up @@ -64,7 +64,7 @@ def _serialize_value(value):
return value

@staticmethod
async def load(prefix: str, connection_url: str) -> Tuple[bool, str]:
async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]:
"""
Load the graph data from a PostgreSQL database into the graph database.

Expand All @@ -86,26 +86,29 @@ async def load(prefix: str, connection_url: str) -> Tuple[bool, str]:
db_name = db_name.split('?')[0]

# Get all table information
yield True, "Extracting table information..."
entities = PostgresLoader.extract_tables_info(cursor)

yield True, "Extracting relationship information..."
# Get all relationship information
relationships = PostgresLoader.extract_relationships(cursor)

# Close database connection
cursor.close()
conn.close()

yield True, "Loading data into graph..."
# Load data into graph
await load_to_graph(f"{prefix}_{db_name}", entities, relationships,
db_name=db_name, db_url=connection_url)

return True, (f"PostgreSQL schema loaded successfully. "
yield True, (f"PostgreSQL schema loaded successfully. "
f"Found {len(entities)} tables.")

except psycopg2.Error as e:
return False, f"PostgreSQL connection error: {str(e)}"
yield False, f"PostgreSQL connection error: {str(e)}"
except Exception as e:
return False, f"Error loading PostgreSQL schema: {str(e)}"
yield False, f"Error loading PostgreSQL schema: {str(e)}"

@staticmethod
def extract_tables_info(cursor) -> Dict[str, Any]:
Expand Down
141 changes: 94 additions & 47 deletions api/routes/database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
"""Database connection routes for the text2sql API."""

import logging
import json
import time

from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import JSONResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from api.auth.user_management import token_required
Expand All @@ -11,23 +14,25 @@

database_router = APIRouter()

# Use the same delimiter as in the JavaScript frontend for streaming chunks
MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||"

class DatabaseConnectionRequest(BaseModel):
"""Database connection request model.

Args:
BaseModel (_type_): _description_
"""
url: str

url: str

@database_router.post("/database", operation_id="connect_database")
@token_required
async def connect_database(request: Request, db_request: DatabaseConnectionRequest):
"""
Accepts a JSON payload with a database URL and attempts to connect.
Supports both PostgreSQL and MySQL databases.
Returns success or error message.
Streams progress steps as a sequence of JSON messages separated by MESSAGE_DELIMITER.
"""
url = db_request.url
if not url:
Expand All @@ -37,52 +42,94 @@ async def connect_database(request: Request, db_request: DatabaseConnectionReque
if not isinstance(url, str) or len(url.strip()) == 0:
raise HTTPException(status_code=400, detail="Invalid URL format")

try:
success = False
result = ""

# Check for PostgreSQL URL
if url.startswith("postgres://") or url.startswith("postgresql://"):
async def generate():
overall_start = time.perf_counter()
steps_counter = 0
try:
# Step 1: Start
steps_counter += 1
yield json.dumps(
{
"type": "reasoning_step",
"message": f"Step {steps_counter}: Starting database connection",
}
) + MESSAGE_DELIMITER

# Step 2: Determine type
db_type = None
if url.startswith("postgres://") or url.startswith("postgresql://"):
db_type = "postgresql"
loader = PostgresLoader
elif url.startswith("mysql://"):
db_type = "mysql"
loader = MySQLLoader
else:
yield json.dumps(
{"type": "error", "message": "Invalid database URL format"}
) + MESSAGE_DELIMITER
return

steps_counter += 1
yield json.dumps(
{
"type": "reasoning_step",
"message": f"Step {steps_counter}: Detected database type: {db_type}. "
"Attempting to load schema...",
}
) + MESSAGE_DELIMITER

# Step 3: Attempt to load schema using the loader
success, result = [False, ""]
try:
# Attempt to connect/load using the PostgreSQL loader
success, result = await PostgresLoader.load(request.state.user_id, url)
except (ValueError, ConnectionError) as e:
logging.error("PostgreSQL connection error: %s", str(e))
raise HTTPException(
status_code=500,
detail="Failed to connect to PostgreSQL database",
load_start = time.perf_counter()
async for progress in loader.load(request.state.user_id, url):
success, result = progress
if success:
steps_counter += 1
yield json.dumps(
{
"type": "reasoning_step",
"message": f"Step {steps_counter}: {result}",
}
) + MESSAGE_DELIMITER
else:
break

load_elapsed = time.perf_counter() - load_start
logging.info(
"Database load attempt finished in %.2f seconds", load_elapsed
)

# Check for MySQL URL
elif url.startswith("mysql://"):
try:
# Attempt to connect/load using the MySQL loader
success, result = await MySQLLoader.load(request.state.user_id, url)
except (ValueError, ConnectionError) as e:
logging.error("MySQL connection error: %s", str(e))
raise HTTPException(
status_code=500, detail="Failed to connect to MySQL database"
)

else:
raise HTTPException(
status_code=400,
detail=(
"Invalid database URL. Supported formats: postgresql:// "
"or mysql://"
),
if success:
yield json.dumps(
{
"type": "final_result",
"success": True,
"message": "Database connected and schema loaded successfully",
}
) + MESSAGE_DELIMITER
else:
# Don't stream the full internal result; give higher-level error
logging.error("Database loader failed: %s", str(result))
yield json.dumps(
{"type": "error", "message": "Failed to load database schema"}
) + MESSAGE_DELIMITER
except Exception as e:
logging.exception("Error while loading database schema: %s", str(e))
yield json.dumps(
{"type": "error", "message": "Error connecting to database"}
) + MESSAGE_DELIMITER

except Exception as e:
logging.exception("Unexpected error in connect_database stream: %s", str(e))
yield json.dumps(
{"type": "error", "message": "Internal server error"}
) + MESSAGE_DELIMITER
finally:
overall_elapsed = time.perf_counter() - overall_start
logging.info(
"connect_database processing completed - Total time: %.2f seconds",
overall_elapsed,
)

if success:
return JSONResponse(content={
"success": True,
"message": "Database connected successfully"
})

# Don't return detailed error messages to prevent information exposure
logging.error("Database loader failed: %s", result)
raise HTTPException(status_code=400, detail="Failed to load database schema")

except (ValueError, TypeError) as e:
logging.error("Unexpected error in database connection: %s", str(e))
raise HTTPException(status_code=500, detail="Internal server error")
return StreamingResponse(generate(), media_type="application/json")

Check warning

Code scanning / CodeQL

Information exposure through an exception Medium

Stack trace information flows to this location and may be exposed to an external user.
Stack trace information flows to this location and may be exposed to an external user.
Stack trace information flows to this location and may be exposed to an external user.
Stack trace information flows to this location and may be exposed to an external user.

Copilot Autofix

AI 8 months ago

To address the vulnerability of exposing internal error information to users, the fix must prevent potentially sensitive data from being returned in streamed JSON responses. To do this, the PostgresLoader.load method (in api/loaders/postgres_loader.py) must never yield detailed exception messages, instead only yielding generic, safe error messages. Internal details can and should be logged on the server.

  • Update the except blocks in PostgresLoader.load so that, instead of yielding f-strings containing str(e), they log that info and only yield a sanitized message to the caller (route).
  • No changes are required in the route itself if we can guarantee that loader always yields only generic error messages (as the route already returns friendly error text in other places).
  • Only file api/loaders/postgres_loader.py requires editing to update the error yielding logic.

Suggested changeset 1
api/loaders/postgres_loader.py
Outside changed files

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/api/loaders/postgres_loader.py b/api/loaders/postgres_loader.py
--- a/api/loaders/postgres_loader.py
+++ b/api/loaders/postgres_loader.py
@@ -106,9 +106,11 @@
                          f"Found {len(entities)} tables.")
 
         except psycopg2.Error as e:
-            yield False, f"PostgreSQL connection error: {str(e)}"
+            logging.error("PostgreSQL connection error: %s", str(e))
+            yield False, "Could not connect to the PostgreSQL database. Please check the connection details."
         except Exception as e:
-            yield False, f"Error loading PostgreSQL schema: {str(e)}"
+            logging.exception("Unexpected error while loading PostgreSQL schema: %s", str(e))
+            yield False, "An internal error occurred while loading the PostgreSQL schema."
 
     @staticmethod
     def extract_tables_info(cursor) -> Dict[str, Any]:
EOF
@@ -106,9 +106,11 @@
f"Found {len(entities)} tables.")

except psycopg2.Error as e:
yield False, f"PostgreSQL connection error: {str(e)}"
logging.error("PostgreSQL connection error: %s", str(e))
yield False, "Could not connect to the PostgreSQL database. Please check the connection details."
except Exception as e:
yield False, f"Error loading PostgreSQL schema: {str(e)}"
logging.exception("Unexpected error while loading PostgreSQL schema: %s", str(e))
yield False, "An internal error occurred while loading the PostgreSQL schema."

@staticmethod
def extract_tables_info(cursor) -> Dict[str, Any]:
Copilot is powered by AI and may make mistakes. Always verify output.
37 changes: 37 additions & 0 deletions app/public/css/modals.css
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,43 @@
overflow-x: auto;
}

/* Styles for incremental database connection steps shown in the connect modal */
#db-connection-steps {
margin: 16px 24px;
font-size: 14px;
color: var(--text-primary);
}

#db-connection-steps-list {
list-style: none;
padding: 0;
margin: 0;
max-height: 220px;
overflow: auto;
}

.db-connection-step {
display: flex;
align-items: center;
margin: 8px 0;
}

.db-connection-step .step-icon {
display: inline-flex;
width: 20px;
height: 20px;
margin-right: 8px;
border-radius: 50%;
align-items: center;
justify-content: center;
font-size: 12px;
line-height: 20px;
}

.db-connection-step .step-icon.pending { color: #1f6feb; }
.db-connection-step .step-icon.success { color: #16a34a; }
.db-connection-step .step-icon.error { color: #dc2626; }

.alert {
padding: 1.5em;
border-radius: 6px;
Expand Down
6 changes: 6 additions & 0 deletions app/templates/components/database_modal.j2
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
<div class="db-modal-content">
<h2 class="db-modal-title" id="db-modal-title">Connect to Database</h2>
<input type="text" id="db-url-input" class="db-modal-input" placeholder="Select database type first..." disabled />

<!-- Connection progress steps will appear here -->
<div id="db-connection-steps">
<ul id="db-connection-steps-list"></ul>
</div>

<div class="db-modal-actions">
<button id="db-modal-connect" class="db-modal-btn db-modal-connect" disabled>
<span class="db-modal-connect-text">Connect</span>
Expand Down
Loading
Loading