diff --git a/.github/wordlist.txt b/.github/wordlist.txt index 5eacb0df..af8192d5 100644 --- a/.github/wordlist.txt +++ b/.github/wordlist.txt @@ -121,3 +121,6 @@ pytest Radix Zod Dependabot +pymssql +sqlserver +SQLServerLoader diff --git a/api/core/schema_loader.py b/api/core/schema_loader.py index b1568514..0f59c4d6 100644 --- a/api/core/schema_loader.py +++ b/api/core/schema_loader.py @@ -14,6 +14,7 @@ from api.loaders.postgres_loader import PostgresLoader from api.loaders.mysql_loader import MySQLLoader from api.loaders.snowflake_loader import SnowflakeLoader +from api.loaders.sqlserver_loader import SQLServerLoader # Use the same delimiter as in the JavaScript frontend for streaming chunks MESSAGE_DELIMITER = "|||FALKORDB_MESSAGE_BOUNDARY|||" @@ -48,6 +49,9 @@ def _step_detect_db_type(steps_counter: int, url: str) -> tuple[type[BaseLoader] elif url.startswith("snowflake://"): db_type = "snowflake" loader = SnowflakeLoader + elif url.startswith("sqlserver://"): + db_type = "sqlserver" + loader = SQLServerLoader else: raise InvalidArgumentError("Invalid database URL format") diff --git a/api/core/text2sql.py b/api/core/text2sql.py index 65bc0b4f..76f380af 100644 --- a/api/core/text2sql.py +++ b/api/core/text2sql.py @@ -21,6 +21,7 @@ from api.loaders.postgres_loader import PostgresLoader from api.loaders.mysql_loader import MySQLLoader from api.loaders.snowflake_loader import SnowflakeLoader +from api.loaders.sqlserver_loader import SQLServerLoader from api.memory.graphiti_tool import MemoryTool from api.sql_utils import SQLIdentifierQuoter, DatabaseSpecificQuoter @@ -87,6 +88,8 @@ def get_database_type_and_loader(db_url: str): return 'mysql', MySQLLoader if db_url_lower.startswith('snowflake://'): return 'snowflake', SnowflakeLoader + if db_url_lower.startswith('sqlserver://'): + return 'sqlserver', SQLServerLoader # Default to PostgresLoader for backward compatibility return 'postgresql', PostgresLoader diff --git a/api/loaders/sqlserver_loader.py 
"""SQL Server loader for loading database schemas into FalkorDB graphs."""

import datetime
import decimal
import logging
import re
import uuid
from typing import AsyncGenerator, Dict, Any, List, Tuple
from urllib.parse import unquote

import tqdm
import pymssql

from api.loaders.base_loader import BaseLoader
from api.loaders.graph_loader import load_to_graph


class SQLServerQueryError(Exception):
    """Exception raised for SQL Server query execution errors."""


class SQLServerConnectionError(Exception):
    """Exception raised for SQL Server connection errors."""

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


class SQLServerLoader(BaseLoader):
    """
    Loader for SQL Server databases that connects and extracts schema information.

    All methods are static; the class is used as a namespace by the schema
    loader and the text2sql pipeline.

    NOTE(review): tables are looked up by bare name (``sys.tables.name``)
    without a schema qualifier, so same-named tables in different schemas are
    merged and non-default-schema tables may fail sample queries. Fixing that
    needs a schema-aware interface and is intentionally left out here.
    """

    # DDL operations that modify database schema # pylint: disable=duplicate-code
    SCHEMA_MODIFYING_OPERATIONS = {
        'CREATE', 'ALTER', 'DROP', 'RENAME', 'TRUNCATE'
    }

    # More specific patterns for schema-affecting operations
    SCHEMA_PATTERNS = [  # pylint: disable=duplicate-code
        r'^\s*CREATE\s+TABLE',
        r'^\s*CREATE\s+INDEX',
        r'^\s*CREATE\s+UNIQUE\s+INDEX',
        r'^\s*ALTER\s+TABLE',
        r'^\s*DROP\s+TABLE',
        r'^\s*DROP\s+INDEX',
        r'^\s*RENAME\s+TABLE',
        r'^\s*TRUNCATE\s+TABLE',
        r'^\s*CREATE\s+VIEW',
        r'^\s*DROP\s+VIEW',
        r'^\s*CREATE\s+SCHEMA',
        r'^\s*DROP\s+SCHEMA',
    ]

    # T-SQL has no RENAME statement; objects and columns are renamed through
    # the sp_rename procedure, optionally qualified (sys.sp_rename, db..sp_rename).
    _SP_RENAME_PATTERN = re.compile(
        r'^\s*EXEC(?:UTE)?\s+(?:[\[\]\w]*\.)*\[?sp_rename(?:\]|\b)',
        re.IGNORECASE,
    )

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a bracket-delimited identifier, doubling any ``]``."""
        return "[" + str(name).replace("]", "]]") + "]"

    @staticmethod
    def _close_quietly(cursor, conn) -> None:
        """Close *cursor* then *conn*, skipping ``None`` and ignoring close errors."""
        for resource in (cursor, conn):
            if resource is None:
                continue
            try:
                resource.close()
            except Exception as exc:  # pylint: disable=broad-exception-caught
                # Cleanup must never mask the error that triggered it.
                logging.debug("Ignoring error while closing SQL Server resource: %s", exc)

    @staticmethod
    def _rollback_quietly(conn) -> None:
        """Roll back *conn* if it exists, ignoring errors from a dead connection."""
        if conn is None:
            return
        try:
            conn.rollback()
        except Exception as exc:  # pylint: disable=broad-exception-caught
            logging.debug("Ignoring error during SQL Server rollback: %s", exc)

    @staticmethod
    def _execute_sample_query(
        cursor, table_name: str, col_name: str, sample_size: int = 3
    ) -> List[Any]:
        """
        Execute query to get random sample values for a column.

        The distinct values are selected in a derived table and shuffled in the
        outer query: SQL Server rejects ``SELECT DISTINCT ... ORDER BY NEWID()``
        because the ORDER BY expression is not in the select list.

        Args:
            cursor: Database cursor (tuple rows or ``as_dict=True`` rows)
            table_name: Name of the table to sample from
            col_name: Name of the column to sample
            sample_size: Maximum number of distinct values to return

        Returns:
            List of up to ``sample_size`` non-NULL distinct values
        """
        column = SQLServerLoader._quote_identifier(col_name)
        table = SQLServerLoader._quote_identifier(table_name)
        query = (
            f"SELECT TOP {int(sample_size)} sample_value"
            f" FROM (SELECT DISTINCT {column} AS sample_value"
            f" FROM {table}"
            f" WHERE {column} IS NOT NULL) AS distinct_values"
            f" ORDER BY NEWID()"
        )
        cursor.execute(query)

        samples = []
        for row in cursor.fetchall():
            # load() hands in an as_dict cursor, whose rows are keyed by name.
            value = row["sample_value"] if isinstance(row, dict) else row[0]
            if value is not None:
                samples.append(value)
        return samples

    @staticmethod
    def _serialize_value(value):
        """
        Convert non-JSON serializable values to JSON serializable format.

        Args:
            value: The value to serialize

        Returns:
            JSON serializable version of the value: ISO strings for
            date/datetime/time, float for Decimal, hex string for bytes,
            string for UUID; anything else (including None) is returned as is
        """
        if isinstance(value, (datetime.date, datetime.datetime, datetime.time)):
            return value.isoformat()
        if isinstance(value, decimal.Decimal):
            return float(value)
        if isinstance(value, bytes):
            return value.hex()
        if isinstance(value, uuid.UUID):
            # uniqueidentifier columns come back as uuid.UUID objects
            return str(value)
        return value

    @staticmethod
    def _parse_sqlserver_url(connection_url: str) -> Dict[str, Any]:
        """
        Parse SQL Server connection URL into components.

        User name, password and database name are percent-decoded, so
        credentials containing reserved characters can be passed URL-encoded.

        Args:
            connection_url: SQL Server connection URL in format:
                          sqlserver://username:password@host:port/database

        Returns:
            Dict with connection parameters (server, port, user, password,
            database) suitable for ``pymssql.connect(**params)``

        Raises:
            ValueError: If the URL is malformed, lacks credentials/host or a
                database name, or has a non-numeric port
        """
        scheme = 'sqlserver://'
        if not connection_url.startswith(scheme):
            raise ValueError(
                "Invalid SQL Server URL format. Expected "
                "sqlserver://username:password@host:port/database"
            )
        remainder = connection_url[len(scheme):]

        if '@' not in remainder:
            raise ValueError("SQL Server URL must include username and host")

        credentials, host_db = remainder.split('@', 1)

        if '/' not in host_db:
            raise ValueError("SQL Server URL must include database name")

        host_port, database = host_db.split('/', 1)

        # A host name can never contain '@'; anything left of the last '@'
        # here is the tail of a password that was not percent-encoded.
        if '@' in host_port:
            password_tail, host_port = host_port.rsplit('@', 1)
            credentials = f"{credentials}@{password_tail}"

        if ':' in credentials:
            username, password = credentials.split(':', 1)
        else:
            username, password = credentials, ""

        # Drop query parameters, then decode the database name
        database = unquote(database.split('?')[0])
        if not database:
            raise ValueError("SQL Server URL must include database name")

        if ':' in host_port:
            host, port_str = host_port.split(':', 1)
            try:
                port = int(port_str)
            except ValueError as exc:
                raise ValueError(f"Invalid SQL Server port: {port_str!r}") from exc
        else:
            host = host_port
            port = 1433

        return {
            'server': host,
            'port': port,
            'user': unquote(username),
            'password': unquote(password),
            'database': database
        }

    @staticmethod
    async def load(prefix: str, connection_url: str) -> AsyncGenerator[tuple[bool, str], None]:
        """
        Load the graph data from a SQL Server database into the graph database.

        Args:
            prefix: Prefix of the target graph name (graph is ``{prefix}_{database}``)
            connection_url: SQL Server connection URL in format:
                          sqlserver://username:password@host:port/database

        Yields:
            Tuple[bool, str]: Success status and progress/result message. A
            failure is reported as a final ``(False, message)`` item rather
            than by raising.
        """
        conn = None
        cursor = None
        try:
            # Parse connection URL
            conn_params = SQLServerLoader._parse_sqlserver_url(connection_url)

            # Connect to SQL Server database
            conn = pymssql.connect(**conn_params)  # pylint: disable=no-member
            cursor = conn.cursor(as_dict=True)

            # Get database name
            db_name = conn_params['database']

            # Get all table information
            yield True, "Extracting table information..."
            entities = SQLServerLoader.extract_tables_info(cursor, db_name)

            # Get all relationship information
            yield True, "Extracting relationship information..."
            relationships = SQLServerLoader.extract_relationships(cursor, db_name)

            # Release the database connection before the (slow) graph load
            SQLServerLoader._close_quietly(cursor, conn)
            cursor = conn = None

            # Load data into graph
            yield True, "Loading data into graph..."
            await load_to_graph(f"{prefix}_{db_name}", entities, relationships,
                                db_name=db_name, db_url=connection_url)

            yield True, (f"SQL Server schema loaded successfully. "
                         f"Found {len(entities)} tables.")

        except pymssql.Error as e:
            logging.error("SQL Server connection error: %s", e)
            yield False, "Failed to connect to SQL Server database"
        except Exception as e:  # pylint: disable=broad-exception-caught
            logging.error("Error loading SQL Server schema: %s", e)
            yield False, "Failed to load SQL Server database schema"
        finally:
            # Runs on failure and when the consumer abandons the generator
            SQLServerLoader._close_quietly(cursor, conn)

    @staticmethod
    def extract_tables_info(cursor, db_name: str) -> Dict[str, Any]:  # pylint: disable=unused-argument
        """
        Extract table and column information from SQL Server database.

        Args:
            cursor: Database cursor returning dict rows (``as_dict=True``)
            db_name: Database name

        Returns:
            Dict mapping table name to its description, columns, foreign keys
            and the list of column descriptions
        """
        entities = {}

        # Get all user tables. ep.class = 1 restricts extended properties to
        # objects/columns; the sql_variant value is cast so the driver always
        # receives plain text.
        cursor.execute("""
            SELECT
                t.name AS table_name,
                ISNULL(CAST(ep.value AS NVARCHAR(4000)), '') AS table_comment
            FROM sys.tables t
            LEFT JOIN sys.extended_properties ep
                ON ep.major_id = t.object_id
                AND ep.minor_id = 0
                AND ep.class = 1
                AND ep.name = 'MS_Description'
            WHERE t.is_ms_shipped = 0
            ORDER BY t.name;
        """)

        tables = cursor.fetchall()

        for table_info in tqdm.tqdm(tables, desc="Extracting table information"):
            table_name = table_info['table_name']
            table_comment = table_info['table_comment']

            # Get column information for this table
            columns_info = SQLServerLoader.extract_columns_info(cursor, db_name, table_name)

            # Get foreign keys for this table
            foreign_keys = SQLServerLoader.extract_foreign_keys(cursor, db_name, table_name)

            entities[table_name] = {
                # Fall back to a generated description when no comment exists
                'description': table_comment or f"Table: {table_name}",
                'columns': columns_info,
                'foreign_keys': foreign_keys,
                # Column descriptions are kept separately for batch embedding
                'col_descriptions': [
                    col_info['description'] for col_info in columns_info.values()
                ]
            }

        return entities

    @staticmethod
    def extract_columns_info(cursor, db_name: str, table_name: str) -> Dict[str, Any]:  # pylint: disable=unused-argument
        """
        Extract column information for a specific table.

        Args:
            cursor: Database cursor returning dict rows (``as_dict=True``)
            db_name: Database name
            table_name: Name of the table

        Returns:
            Dict mapping column name to type, nullability, key type,
            description, default and sample values
        """
        cursor.execute("""
            SELECT
                c.name AS column_name,
                tp.name AS data_type,
                c.is_nullable,
                dc.definition AS column_default,
                CASE
                    WHEN pk.column_id IS NOT NULL THEN 'PRI'
                    WHEN fk.parent_column_id IS NOT NULL THEN 'MUL'
                    WHEN uc.column_id IS NOT NULL THEN 'UNI'
                    ELSE ''
                END AS column_key,
                ISNULL(CAST(ep.value AS NVARCHAR(4000)), '') AS column_comment
            FROM sys.columns c
            JOIN sys.types tp ON c.user_type_id = tp.user_type_id
            JOIN sys.tables t ON c.object_id = t.object_id
            LEFT JOIN sys.default_constraints dc ON c.default_object_id = dc.object_id
            LEFT JOIN (
                SELECT ic.object_id, ic.column_id
                FROM sys.index_columns ic
                JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
                WHERE i.is_primary_key = 1
            ) pk ON c.object_id = pk.object_id AND c.column_id = pk.column_id
            LEFT JOIN sys.foreign_key_columns fk
                ON fk.parent_object_id = c.object_id AND fk.parent_column_id = c.column_id
            LEFT JOIN (
                SELECT ic.object_id, ic.column_id
                FROM sys.index_columns ic
                JOIN sys.indexes i ON ic.object_id = i.object_id AND ic.index_id = i.index_id
                WHERE i.is_unique = 1 AND i.is_primary_key = 0
            ) uc ON c.object_id = uc.object_id AND c.column_id = uc.column_id
            LEFT JOIN sys.extended_properties ep
                ON ep.major_id = c.object_id
                AND ep.minor_id = c.column_id
                AND ep.class = 1
                AND ep.name = 'MS_Description'
            WHERE t.name = %s
            ORDER BY c.column_id;
        """, (table_name,))

        key_labels = {'PRI': 'PRIMARY KEY', 'MUL': 'FOREIGN KEY', 'UNI': 'UNIQUE KEY'}
        columns_info = {}

        for col_info in cursor.fetchall():
            col_name = col_info['column_name']
            # The fk/uc joins repeat a column that belongs to several foreign
            # keys or unique indexes; every repeat carries the same key class,
            # so skip it instead of re-running the sample query.
            if col_name in columns_info:
                continue

            data_type = col_info['data_type']
            is_nullable = 'YES' if col_info['is_nullable'] else 'NO'
            column_default = col_info['column_default']
            column_comment = col_info['column_comment']
            key_type = key_labels.get(col_info['column_key'], 'NONE')

            # Generate column description
            if column_comment:
                description_parts = [str(column_comment)]
            else:
                description_parts = [f"Column {col_name} of type {data_type}"]

            if key_type != 'NONE':
                description_parts.append(f"({key_type})")

            if is_nullable == 'NO':
                description_parts.append("(NOT NULL)")

            if column_default is not None:
                description_parts.append(f"(Default: {column_default})")

            # Extract sample values for the column (stored separately, not in description)
            sample_values = SQLServerLoader.extract_sample_values_for_column(
                cursor, table_name, col_name
            )

            columns_info[col_name] = {
                'type': data_type,
                'null': is_nullable,
                'key': key_type,
                'description': ' '.join(description_parts),
                'default': column_default,
                'sample_values': sample_values
            }

        return columns_info

    @staticmethod
    def extract_foreign_keys(cursor, db_name: str, table_name: str) -> List[Dict[str, str]]:  # pylint: disable=unused-argument
        """
        Extract foreign key information for a specific table.

        Args:
            cursor: Database cursor returning dict rows (``as_dict=True``)
            db_name: Database name
            table_name: Name of the table

        Returns:
            List of foreign key dictionaries, one per constrained column
        """
        cursor.execute("""
            SELECT
                fk.name AS constraint_name,
                cp.name AS column_name,
                rt.name AS referenced_table_name,
                cr.name AS referenced_column_name
            FROM sys.foreign_keys fk
            JOIN sys.foreign_key_columns fkc
                ON fk.object_id = fkc.constraint_object_id
            JOIN sys.columns cp
                ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id
            JOIN sys.tables rt
                ON fkc.referenced_object_id = rt.object_id
            JOIN sys.columns cr
                ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id
            JOIN sys.tables pt
                ON fkc.parent_object_id = pt.object_id
            WHERE pt.name = %s
            ORDER BY fk.name;
        """, (table_name,))

        return [
            {
                'constraint_name': fk_info['constraint_name'],
                'column': fk_info['column_name'],
                'referenced_table': fk_info['referenced_table_name'],
                'referenced_column': fk_info['referenced_column_name']
            }
            for fk_info in cursor.fetchall()
        ]

    @staticmethod
    def extract_relationships(cursor, db_name: str) -> Dict[str, List[Dict[str, str]]]:  # pylint: disable=unused-argument
        """
        Extract all relationship information from the database.

        Args:
            cursor: Database cursor returning dict rows (``as_dict=True``)
            db_name: Database name

        Returns:
            Dict mapping foreign key constraint name to the list of column
            relationships it defines
        """
        cursor.execute("""
            SELECT
                pt.name AS table_name,
                fk.name AS constraint_name,
                cp.name AS column_name,
                rt.name AS referenced_table_name,
                cr.name AS referenced_column_name
            FROM sys.foreign_keys fk
            JOIN sys.foreign_key_columns fkc
                ON fk.object_id = fkc.constraint_object_id
            JOIN sys.columns cp
                ON fkc.parent_object_id = cp.object_id AND fkc.parent_column_id = cp.column_id
            JOIN sys.tables pt
                ON fkc.parent_object_id = pt.object_id
            JOIN sys.tables rt
                ON fkc.referenced_object_id = rt.object_id
            JOIN sys.columns cr
                ON fkc.referenced_object_id = cr.object_id AND fkc.referenced_column_id = cr.column_id
            ORDER BY pt.name, fk.name;
        """)

        relationships: Dict[str, List[Dict[str, str]]] = {}
        for rel_info in cursor.fetchall():
            constraint_name = rel_info['constraint_name']
            # Composite keys contribute one entry per column pair
            relationships.setdefault(constraint_name, []).append({
                'from': rel_info['table_name'],
                'to': rel_info['referenced_table_name'],
                'source_column': rel_info['column_name'],
                'target_column': rel_info['referenced_column_name'],
                'note': f'Foreign key constraint: {constraint_name}'
            })

        return relationships

    @staticmethod
    def is_schema_modifying_query(sql_query: str) -> Tuple[bool, str]:
        """
        Check if a SQL query modifies the database schema.

        Any statement whose first keyword is a DDL operation counts as
        schema-modifying (better safe than sorry), as does a call to
        ``sp_rename``, which is how SQL Server renames objects and columns.

        Args:
            sql_query: The SQL query to check

        Returns:
            Tuple of (is_schema_modifying, operation_type); operation_type is
            the upper-cased DDL keyword, 'RENAME' for sp_rename, or '' when the
            query does not modify the schema
        """
        if not sql_query or not sql_query.strip():
            return False, ""

        # Clean and normalize the query
        normalized_query = sql_query.strip().upper()
        first_word = normalized_query.split()[0]

        if first_word in SQLServerLoader.SCHEMA_MODIFYING_OPERATIONS:
            return True, first_word

        if SQLServerLoader._SP_RENAME_PATTERN.match(normalized_query):
            return True, "RENAME"

        return False, ""

    @staticmethod
    async def refresh_graph_schema(graph_id: str, db_url: str) -> Tuple[bool, str]:
        """
        Refresh the graph schema by clearing existing data and reloading from the database.

        Args:
            graph_id: The graph ID to refresh
            db_url: Database connection URL

        Returns:
            Tuple of (success, message); failures are reported in the tuple,
            never raised
        """
        try:
            logging.info("Schema modification detected. Refreshing graph schema.")

            # Import here to avoid circular imports
            from api.extensions import db  # pylint: disable=import-error,import-outside-toplevel

            # Work out the prefix BEFORE dropping anything, so an unusable URL
            # cannot leave the user without a graph. load() names the graph
            # "{prefix}_{database}", so strip exactly that suffix; splitting on
            # the last '_' would be wrong for database names containing '_'.
            db_name = SQLServerLoader._parse_sqlserver_url(db_url)['database']
            suffix = f"_{db_name}"
            if graph_id.endswith(suffix):
                prefix = graph_id[:-len(suffix)]
            else:
                # Graph was not named after this database: fall back to
                # dropping the last '_'-separated part.
                parts = graph_id.split('_')
                prefix = '_'.join(parts[:-1]) if len(parts) >= 2 else graph_id

            # Drop current graph before reloading
            graph = db.select_graph(graph_id)
            await graph.delete()

            # Reuse the existing load method to reload the schema; the last
            # yielded item carries the final outcome.
            success, message = False, ""
            async for progress in SQLServerLoader.load(prefix, db_url):
                success, message = progress

            if success:
                logging.info("Graph schema refreshed successfully.")
                return True, message

            logging.error("Schema refresh failed")
            return False, "Failed to reload schema"

        except Exception as e:  # pylint: disable=broad-exception-caught
            # Log the error and return failure
            error_msg = "Error refreshing graph schema"
            logging.error("%s: %s", error_msg, str(e))
            return False, error_msg

    @staticmethod
    def execute_sql_query(sql_query: str, db_url: str) -> List[Dict[str, Any]]:
        """
        Execute a SQL query on the SQL Server database and return the results.

        Args:
            sql_query: The SQL query to execute
            db_url: SQL Server connection URL in format:
                    sqlserver://username:password@host:port/database

        Returns:
            List of dictionaries containing the query results. Row-returning
            statements give one dict per row (unnamed columns are keyed
            ``column_<position>``); other statements give a single status dict.

        Raises:
            SQLServerQueryError: If connecting or executing the query fails
        """
        conn = None
        cursor = None
        try:
            # Parse connection URL
            conn_params = SQLServerLoader._parse_sqlserver_url(db_url)

            # Connect to SQL Server database
            conn = pymssql.connect(**conn_params)  # pylint: disable=no-member
            # Tuple rows on purpose: an as_dict cursor rejects result sets with
            # unnamed columns (e.g. SELECT COUNT(*) FROM t).
            cursor = conn.cursor()

            # Execute the SQL query
            cursor.execute(sql_query)

            if cursor.description is not None:
                # The statement returns rows (SELECT, INSERT ... OUTPUT, ...)
                column_names = [
                    column[0] or f"column_{position}"
                    for position, column in enumerate(cursor.description, start=1)
                ]
                result_list = []
                for row in cursor.fetchall():
                    pairs = row.items() if isinstance(row, dict) else zip(column_names, row)
                    # Serialize each value to ensure JSON compatibility
                    result_list.append({
                        key: SQLServerLoader._serialize_value(value)
                        for key, value in pairs
                    })
            else:
                # INSERT, UPDATE, DELETE, or other statement without a result set
                affected_rows = cursor.rowcount
                tokens = sql_query.strip().split()
                sql_type = tokens[0].upper() if tokens else ""

                if sql_type in ['INSERT', 'UPDATE', 'DELETE']:
                    result_list = [{
                        "operation": sql_type,
                        "affected_rows": affected_rows,
                        "status": "success"
                    }]
                else:
                    # For other types of queries (CREATE, DROP, etc.)
                    result_list = [{
                        "operation": sql_type,
                        "status": "success"
                    }]

            # Commit in both branches: a row-returning statement may also
            # write, and committing after a pure read is harmless.
            conn.commit()

            return result_list

        except pymssql.Error as e:
            SQLServerLoader._rollback_quietly(conn)
            logging.error("SQL Server query execution error: %s", e)
            raise SQLServerQueryError(f"SQL Server query execution error: {str(e)}") from e
        except Exception as e:
            SQLServerLoader._rollback_quietly(conn)
            logging.error("Error executing SQL query: %s", e)
            raise SQLServerQueryError(f"Error executing SQL query: {str(e)}") from e
        finally:
            # Close database connection on success and failure alike
            SQLServerLoader._close_quietly(cursor, conn)
use (default: " for PostgreSQL/standard SQL) + quote_char: The quote character to use (default: " for PostgreSQL/standard SQL, + use ` for MySQL, [ for SQL Server) Returns: Quoted identifier @@ -62,10 +73,13 @@ def quote_identifier(identifier: str, quote_char: str = '"') -> str: identifier = identifier.strip() # Don't double-quote - if (identifier.startswith('"') and identifier.endswith('"')) or \ - (identifier.startswith('`') and identifier.endswith('`')): + if SQLIdentifierQuoter._is_already_quoted(identifier): return identifier + # SQL Server uses bracket pairs: [identifier] + if quote_char == '[': + return f'[{identifier}]' + return f'{quote_char}{identifier}{quote_char}' @classmethod @@ -167,5 +181,7 @@ def get_quote_char(db_type: str) -> str: """ if db_type.lower() in ['mysql', 'mariadb']: return '`' - # PostgreSQL, SQLite, SQL Server (standard SQL) use double quotes + if db_type.lower() in ['sqlserver', 'mssql']: + return '[' + # PostgreSQL, SQLite use double quotes (standard SQL) return '"' diff --git a/app/src/components/modals/DatabaseModal.tsx b/app/src/components/modals/DatabaseModal.tsx index e3a7f47a..26d7cb6b 100644 --- a/app/src/components/modals/DatabaseModal.tsx +++ b/app/src/components/modals/DatabaseModal.tsx @@ -134,7 +134,7 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { builtUrl.searchParams.set('warehouse', warehouse); dbUrl = builtUrl.toString(); } else { - const protocol = selectedDatabase === 'mysql' ? 'mysql' : 'postgresql'; + const protocol = selectedDatabase === 'mysql' ? 'mysql' : selectedDatabase === 'sqlserver' ? 
'sqlserver' : 'postgresql'; const builtUrl = new URL(`${protocol}://${host}:${port}/${database}`); builtUrl.username = username; builtUrl.password = password; @@ -301,7 +301,7 @@ const DatabaseModal = ({ open, onOpenChange }: DatabaseModalProps) => { Connect to Database - Connect to PostgreSQL, MySQL, or Snowflake database using a connection URL or manual entry.{" "} + Connect to PostgreSQL, MySQL, Snowflake, or SQL Server database using a connection URL or manual entry.{" "} { Snowflake + +
+
+ SQL Server +
+
"""Tests for SQL Server loader functionality."""
# pylint: disable=protected-access

import asyncio
import datetime
import decimal
from unittest.mock import patch, MagicMock

import pytest

from api.loaders.sqlserver_loader import SQLServerLoader


async def _consume_loader(loader_gen):
    """Drain an async loader generator and return its last (success, message) pair."""
    final = (False, "")
    async for ok, text in loader_gen:
        final = (ok, text)
    return final


class TestSQLServerLoader:
    """Unit tests for the SQLServerLoader class."""

    def test_parse_sqlserver_url_valid(self):
        """A fully specified URL is split into all five connection parameters."""
        parsed = SQLServerLoader._parse_sqlserver_url(
            "sqlserver://testuser:testpass@localhost:1433/testdb"
        )
        assert parsed == {
            'server': 'localhost',
            'port': 1433,
            'user': 'testuser',
            'password': 'testpass',
            'database': 'testdb'
        }

    def test_parse_sqlserver_url_default_port(self):
        """A URL without a port falls back to 1433."""
        parsed = SQLServerLoader._parse_sqlserver_url(
            "sqlserver://testuser:testpass@localhost/testdb"
        )
        assert parsed['port'] == 1433
        assert parsed['server'] == 'localhost'

    def test_parse_sqlserver_url_no_password(self):
        """A URL without a password yields an empty password."""
        parsed = SQLServerLoader._parse_sqlserver_url(
            "sqlserver://testuser@localhost:1433/testdb"
        )
        assert parsed['password'] == ""
        assert parsed['user'] == 'testuser'

    def test_parse_sqlserver_url_with_query_params(self):
        """Query parameters are stripped from the database name."""
        parsed = SQLServerLoader._parse_sqlserver_url(
            "sqlserver://testuser:testpass@localhost:1433/testdb?charset=utf8"
        )
        assert parsed['database'] == 'testdb'

    def test_parse_sqlserver_url_invalid_format(self):
        """A URL with a foreign scheme is rejected."""
        with pytest.raises(ValueError, match="Invalid SQL Server URL format"):
            SQLServerLoader._parse_sqlserver_url("postgresql://user@host/db")

    def test_parse_sqlserver_url_missing_host(self):
        """A URL without credentials and host is rejected."""
        with pytest.raises(ValueError, match="SQL Server URL must include username and host"):
            SQLServerLoader._parse_sqlserver_url("sqlserver://")

    def test_parse_sqlserver_url_missing_database(self):
        """A URL without a database segment is rejected."""
        with pytest.raises(ValueError, match="SQL Server URL must include database name"):
            SQLServerLoader._parse_sqlserver_url("sqlserver://user@host")

    def test_serialize_value(self):
        """Each supported type is converted to its JSON-friendly form."""
        serialize = SQLServerLoader._serialize_value

        # datetime, date, time, decimal, bytes and a plain string, in that order
        cases = [
            (datetime.datetime(2023, 1, 1, 12, 0, 0), "2023-01-01T12:00:00"),
            (datetime.date(2023, 1, 1), "2023-01-01"),
            (datetime.time(12, 0, 0), "12:00:00"),
            (decimal.Decimal("123.45"), 123.45),
            (b'\x00\x01\x02', "000102"),
            ("test", "test"),
        ]
        for raw, expected in cases:
            assert serialize(raw) == expected

        # None passes through untouched
        assert serialize(None) is None

    def test_is_schema_modifying_query(self):
        """DDL statements are flagged; DML and blank input are not."""
        check = SQLServerLoader.is_schema_modifying_query

        ddl_statements = (
            "CREATE TABLE test (id INT)",
            "DROP TABLE test",
            "ALTER TABLE test ADD COLUMN name VARCHAR(50)",
            "  CREATE INDEX idx_name ON test(name)",
        )
        for statement in ddl_statements:
            assert check(statement)[0] is True

        # Regular data statements followed by the empty/blank edge cases
        other_statements = (
            "SELECT * FROM test",
            "INSERT INTO test VALUES (1)",
            "UPDATE test SET name = 'test'",
            "DELETE FROM test WHERE id = 1",
            "",
            "   ",
        )
        for statement in other_statements:
            assert check(statement)[0] is False

    @patch('pymssql.connect')
    def test_connection_error(self, mock_connect):
        """A failing connect call ends the load with a failure message."""
        mock_connect.side_effect = Exception("Connection failed")

        loader_gen = SQLServerLoader.load("test_prefix", "sqlserver://user:pass@host:1433/db")
        success, message = asyncio.run(_consume_loader(loader_gen))

        assert success is False
        assert "Failed to load SQL Server database schema" in message

    @patch('pymssql.connect')
    @patch('api.loaders.sqlserver_loader.load_to_graph')
    def test_successful_load(self, mock_load_to_graph, mock_connect):
        """A load with mocked extraction reports success and writes the graph once."""
        # Fake connection whose cursor reports a single table
        fake_cursor = MagicMock()
        fake_cursor.fetchall.return_value = [
            {'table_name': 'users', 'table_comment': 'User table'}
        ]
        fake_conn = MagicMock()
        fake_conn.cursor.return_value = fake_cursor
        mock_connect.return_value = fake_conn

        # Stub both extraction steps so only load() itself is exercised
        tables_patch = patch.object(
            SQLServerLoader, 'extract_tables_info',
            return_value={'users': {'description': 'User table'}}
        )
        relations_patch = patch.object(
            SQLServerLoader, 'extract_relationships', return_value={}
        )
        with tables_patch:
            with relations_patch:
                loader_gen = SQLServerLoader.load(
                    "test_prefix", "sqlserver://user:pass@localhost:1433/testdb"
                )
                success, message = asyncio.run(_consume_loader(loader_gen))

        assert success is True
        assert "SQL Server schema loaded successfully" in message
        mock_load_to_graph.assert_called_once()
"https://files.pythonhosted.org/packages/ba/60/a2e8a8a38f7be21d54402e2b3365cd56f1761ce9f2706c97f864e8aa8300/pymssql-2.3.13-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cf4f32b4a05b66f02cb7d55a0f3bcb0574a6f8cf0bee4bea6f7b104038364733", size = 3158689, upload-time = "2026-02-14T04:59:46.982Z" }, + { url = "https://files.pythonhosted.org/packages/43/9e/0cf0ffb9e2f73238baf766d8e31d7237b5bee3cc1bb29a376b404610994a/pymssql-2.3.13-cp312-cp312-macosx_15_0_x86_64.whl", hash = "sha256:2b056eb175955f7fb715b60dc1c0c624969f4d24dbdcf804b41ab1e640a2b131", size = 2960018, upload-time = "2026-02-14T04:59:48.668Z" }, + { url = "https://files.pythonhosted.org/packages/93/ea/bc27354feaca717faa4626911f6b19bb62985c87dda28957c63de4de5895/pymssql-2.3.13-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:319810b89aa64b99d9c5c01518752c813938df230496fa2c4c6dda0603f04c4c", size = 3065719, upload-time = "2026-02-14T04:59:50.369Z" }, + { url = "https://files.pythonhosted.org/packages/1e/7a/8028681c96241fb5fc850b87c8959402c353e4b83c6e049a99ffa67ded54/pymssql-2.3.13-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c0ea72641cb0f8bce7ad8565dbdbda4a7437aa58bce045f2a3a788d71af2e4be", size = 3190567, upload-time = "2026-02-14T04:59:52.202Z" }, + { url = "https://files.pythonhosted.org/packages/aa/f1/ab5b76adbbd6db9ce746d448db34b044683522e7e7b95053f9dd0165297b/pymssql-2.3.13-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1493f63d213607f708a5722aa230776ada726ccdb94097fab090a1717a2534e0", size = 3710481, upload-time = "2026-02-14T04:59:54.01Z" }, + { url = "https://files.pythonhosted.org/packages/59/aa/2fa0951475cd0a1829e0b8bfbe334d04ece4bce11546a556b005c4100689/pymssql-2.3.13-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eb3275985c23479e952d6462ae6c8b2b6993ab6b99a92805a9c17942cf3d5b3d", size = 3453789, upload-time = "2026-02-14T04:59:56.841Z" }, + { url = 
"https://files.pythonhosted.org/packages/78/08/8cd2af9003f9fc03912b658a64f5a4919dcd68f0dd3bbc822b49a3d14fd9/pymssql-2.3.13-cp312-cp312-win_amd64.whl", hash = "sha256:a930adda87bdd8351a5637cf73d6491936f34e525a5e513068a6eac742f69cdb", size = 1994709, upload-time = "2026-02-14T04:59:58.972Z" }, + { url = "https://files.pythonhosted.org/packages/d4/4f/ee15b1f6b11e7c3accdc7da7840a019b63f12ba09eaa008acc601182f516/pymssql-2.3.13-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:30918bb044242865c01838909777ef5e0f1b9ecd7f5882346aefa57f4414b29c", size = 3156333, upload-time = "2026-02-14T05:00:01.21Z" }, + { url = "https://files.pythonhosted.org/packages/79/03/aea5c77bad4a52649a1d9f786a1d9ce1c83d50f1a75df288e292737b6d80/pymssql-2.3.13-cp313-cp313-macosx_15_0_x86_64.whl", hash = "sha256:1c6d0b2d7961f159a07e4f0d8cc81f70ceab83f5e7fd1e832a2d069e1d67ee4e", size = 2957990, upload-time = "2026-02-14T05:00:03.11Z" }, + { url = "https://files.pythonhosted.org/packages/5f/f8/30ac16fba32ff066b05f12c392d7b812fe11f06cb62d1d86ca5177c50a8b/pymssql-2.3.13-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:16c5957a3c9e51a03276bfd76a22431e2bc4c565e2e95f2cbb3559312edda230", size = 3065264, upload-time = "2026-02-14T05:00:05.377Z" }, + { url = "https://files.pythonhosted.org/packages/a9/98/7568447bf85921d21453fd56e19b6c9591d595fde0546c5a569f3ae937a8/pymssql-2.3.13-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0fddd24efe9d18bbf174fab7c6745b0927773718387f5517cf8082241f721a68", size = 3190039, upload-time = "2026-02-14T05:00:06.925Z" }, + { url = "https://files.pythonhosted.org/packages/35/f1/4d9d275ebaac42cdd49d40d504ccb648f27710660c8b60cc427752438c09/pymssql-2.3.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:123c55ee41bc7a82c76db12e2eb189b50d0d7a11222b4f8789206d1cda3b33b9", size = 3710151, upload-time = "2026-02-14T05:00:08.424Z" }, + { url = 
"https://files.pythonhosted.org/packages/6f/bd/a5cc6244fd27d3ea0cc82f12a7d38a24d7fd90b0022afd250014e8bfba15/pymssql-2.3.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e053b443e842f9e1698fcb2b23a4bff1ff3d410894d880064e754ad823d541e5", size = 3453156, upload-time = "2026-02-14T05:00:09.978Z" }, + { url = "https://files.pythonhosted.org/packages/26/d0/c20ff0bbffd18db528bcc7b0c68b25c12ad563ed67c56ceca87c58f7399e/pymssql-2.3.13-cp313-cp313-win_amd64.whl", hash = "sha256:5c045c0f1977a679cc30d5acd9da3f8aeb2dc6e744895b26444b4a2f20dad9a0", size = 1995236, upload-time = "2026-02-14T05:00:11.495Z" }, + { url = "https://files.pythonhosted.org/packages/ec/5f/6b64f78181d680f655ab40ba7b34cb68c045a2f4e04a10a70d768cd383b7/pymssql-2.3.13-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:fc5482969c813b0a45ce51c41844ae5bfa8044ad5ef8b4820ef6de7d4545b7f2", size = 3158377, upload-time = "2026-02-14T05:00:13.581Z" }, + { url = "https://files.pythonhosted.org/packages/ff/24/155dbb0992c431496d440f47fb9d587cd0059ee20baf65e3d891794d862a/pymssql-2.3.13-cp314-cp314-macosx_15_0_x86_64.whl", hash = "sha256:ff5be7ab1d643dbce2ee3424d2ef9ae8e4146cf75bd20946bc7a6108e3ad1e47", size = 2959039, upload-time = "2026-02-14T05:00:15.883Z" }, + { url = "https://files.pythonhosted.org/packages/c9/89/b453dd1b1188779621fb974ac715ab2e738f4a0b69f7291ab014298bd80d/pymssql-2.3.13-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8d66ce0a249d2e3b57369048d71e1f00d08dfb90a758d134da0250ae7bc739c1", size = 3063862, upload-time = "2026-02-14T05:00:17.537Z" }, + { url = "https://files.pythonhosted.org/packages/02/e5/96f57c78162013678ecc3f3f7e5fb52c83ee07beef26906d0870770c3ef6/pymssql-2.3.13-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d663c908414a6a032f04d17628138b1782af916afc0df9fefac4751fa394c3ac", size = 3188155, upload-time = "2026-02-14T05:00:19.011Z" }, + { url = 
"https://files.pythonhosted.org/packages/cd/a2/4bee9484734ae0c55d10a2f6ff82dd4e416f52420755161b8760c817ad64/pymssql-2.3.13-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aa5e07eff7e6e8bd4ba22c30e4cb8dd073e138cd272090603609a15cc5dbc75b", size = 3709344, upload-time = "2026-02-14T05:00:21.139Z" }, + { url = "https://files.pythonhosted.org/packages/37/cf/3520d96afa213c88db4f4a1988199db476d869a62afdd5d9c4635c184631/pymssql-2.3.13-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:db77da1a3fc9b5b5c5400639d79d7658ba7ad620957100c5b025be608b562193", size = 3451799, upload-time = "2026-02-14T05:00:22.504Z" }, + { url = "https://files.pythonhosted.org/packages/25/50/4be9bd9cf4b43208a7175117a533ece200cfe4131a39f9909bdc7560ddeb/pymssql-2.3.13-cp314-cp314-win_amd64.whl", hash = "sha256:7d7037d2b5b907acc7906d0479924db2935a70c720450c41339146a4ada2b93d", size = 2049139, upload-time = "2026-02-14T05:00:23.951Z" }, +] + [[package]] name = "pymysql" version = "1.1.2" @@ -2146,6 +2175,7 @@ dependencies = [ { name = "jsonschema" }, { name = "litellm" }, { name = "psycopg2-binary" }, + { name = "pymssql" }, { name = "pymysql" }, { name = "python-dotenv" }, { name = "python-multipart" }, @@ -2176,6 +2206,7 @@ requires-dist = [ { name = "jsonschema", specifier = "~=4.26.0" }, { name = "litellm", specifier = ">=1.83.0" }, { name = "psycopg2-binary", specifier = "~=2.9.11" }, + { name = "pymssql", specifier = "~=2.3.13" }, { name = "pymysql", specifier = "~=1.1.0" }, { name = "python-dotenv", specifier = "~=1.1.0" }, { name = "python-multipart", specifier = "~=0.0.10" },