diff --git a/src/masoniteorm/connections/MSSQLConnection.py b/src/masoniteorm/connections/MSSQLConnection.py index 53b87f95..286ffaa8 100644 --- a/src/masoniteorm/connections/MSSQLConnection.py +++ b/src/masoniteorm/connections/MSSQLConnection.py @@ -6,6 +6,20 @@ CONNECTION_POOL = [] +# Options that are handled explicitly when building the connection string. +# Anything in self.options that is NOT in this set will be appended verbatim +# as additional "Key=Value" pairs in the pyodbc connection string. +_MSSQL_KNOWN_OPTIONS = frozenset( + { + "driver", + "integrated_security", + "connection_timeout", + "authentication", + "instance", + "trusted_connection", + } +) + class MSSQLConnection(BaseConnection): """MSSQL Connection class.""" @@ -25,10 +39,7 @@ def __init__( name=None, ): self.host = host - if port: - self.port = int(port) - else: - self.port = port + self.port = int(port) if port else None self.database = database self.user = user self.password = password @@ -41,6 +52,50 @@ def __init__( if name: self.name = name + def _build_connection_string(self): + """Build the pyodbc connection string from self.options. + + Known options are mapped to their canonical ODBC connection string + keys. Any remaining entries in self.options that are not in + ``_MSSQL_KNOWN_OPTIONS`` are appended verbatim as ``Key=Value`` pairs, + allowing arbitrary ODBC attributes to be passed through. + + Returns: + str: A semicolon-delimited ODBC connection string. + """ + driver = self.options.get("driver", "ODBC Driver 17 for SQL Server") + connection_timeout = str(self.options.get("connection_timeout", "30")) + integrated_security = self.options.get("integrated_security") + trusted_connection = self.options.get("trusted_connection") + authentication = self.options.get("authentication") + instance = self.options.get("instance", "") + + if instance: + instance = "\\" + instance + + parts = [ + f"DRIVER={driver}", + f"SERVER={self.host}{instance},{self.port}", + f"Connection Timeout={connection_timeout}", + f"DATABASE={self.database}", + f"UID={self.user}", + f"PWD={self.password}", + ] + + if integrated_security: + parts.append(f"Integrated Security={integrated_security}") + if trusted_connection: + parts.append(f"Trusted_Connection={trusted_connection}") + if authentication: + parts.append(f"Authentication={authentication}") + + # Append any extra options not handled above. + for key, value in self.options.items(): + if key not in _MSSQL_KNOWN_OPTIONS: + parts.append(f"{key}={value}") + + return ";".join(parts) + def make_connection(self): """This sets the connection on the connection class""" try: @@ -53,18 +108,8 @@ def make_connection(self): if self.has_global_connection(): return self.get_global_connection() - driver = self.options.get("driver", "ODBC Driver 17 for SQL Server") - integrated_security = self.options.get("integrated_security") - connection_timeout = str(self.options.get("connection_timeout", "30")) - authentication = self.options.get("authentication") - instance = self.options.get("instance", "") - trusted_connection = self.options.get("trusted_connection") - - if instance: - instance = "\\" + instance - self._connection = pyodbc.connect( - f"DRIVER={driver};SERVER={self.host}{instance if instance else ''},{self.port};Connection Timeout={connection_timeout};DATABASE={self.database}{f';Integrated Security={integrated_security}' if integrated_security else ''};UID={self.user};PWD={self.password}{f';Trusted_Connection={trusted_connection}' if trusted_connection else ''}{f';Authentication={authentication}' if authentication else ''}", + self._build_connection_string(), autocommit=True, ) @@ -134,7 +179,6 @@ def query(self, query, bindings=(), results="*"): Returns: dict|None -- Returns a dictionary of results or None """ - try: if not self.open: self.make_connection() diff --git a/src/masoniteorm/connections/MySQLConnection.py b/src/masoniteorm/connections/MySQLConnection.py index 01820840..305b793b 100644 --- a/src/masoniteorm/connections/MySQLConnection.py +++ b/src/masoniteorm/connections/MySQLConnection.py @@ -26,9 +26,7 @@ def __init__( name=None, ): self.host = host - self.port = port - if str(port).isdigit(): - self.port = int(self.port) + self.port = int(port) if port else None self.database = database self.user = user diff --git a/src/masoniteorm/connections/PostgresConnection.py b/src/masoniteorm/connections/PostgresConnection.py index 374db744..5f61fcda 100644 --- a/src/masoniteorm/connections/PostgresConnection.py +++ b/src/masoniteorm/connections/PostgresConnection.py @@ -6,6 +6,39 @@ CONNECTION_POOL = [] +# All keyword arguments accepted directly by psycopg2.connect(), excluding +# 'options' which we build ourselves from leftover entries in self.options. +_PSYCOPG2_CONNECT_KWARGS = frozenset( + { + # Core identity + "database", + "dbname", + "user", + "password", + "host", + "port", + # SSL + "sslmode", + "sslcert", + "sslkey", + "sslrootcert", + "sslcrl", + "sslpassword", + # Timeouts & keepalives + "connect_timeout", + "keepalives", + "keepalives_idle", + "keepalives_interval", + "keepalives_count", + # Application identification + "application_name", + "fallback_application_name", + # Misc + "cursor_factory", + "async", + } +) + class PostgresConnection(BaseConnection): """Postgres Connection class.""" @@ -25,10 +58,7 @@ def __init__( name=None, ): self.host = host - if port: - self.port = int(port) - else: - self.port = port + self.port = int(port) if port else None self.database = database self.user = user self.password = password @@ -46,6 +76,47 @@ def __init__( if name: self.name = name + def _build_connect_kwargs(self): + """Return a dict of kwargs ready to be unpacked into psycopg2.connect(). + + self.options is partitioned into two groups: + - Keys that are valid psycopg2.connect() keyword arguments are passed + through directly. + - All remaining keys are treated as PostgreSQL GUC parameters and + appended to the ``options`` string as ``-c key=value`` entries + alongside the schema search_path (when set). + """ + # --- Base connection parameters --------------------------------- + kwargs = { + "database": self.database, + "user": self.user, + "password": self.password, + "host": self.host, + "port": self.port, + } + + # --- Split self.options into direct kwargs vs GUC params -------- + guc_params = {} + for key, value in self.options.items(): + if key in _PSYCOPG2_CONNECT_KWARGS: + kwargs[key] = value + else: + guc_params[key] = value + + # --- Build the options / GUC string ----------------------------- + # search_path comes from the instance schema or full_details, but can + # also be overridden via self.options by passing it as a leftover key. + schema = self.schema or self.full_details.get("schema") + if schema and "search_path" not in guc_params: + guc_params["search_path"] = schema + + if guc_params: + kwargs["options"] = " ".join( + f"-c {k}={v}" for k, v in guc_params.items() + ) + + return kwargs + def make_connection(self): """This sets the connection on the connection class""" if self.has_global_connection(): @@ -66,9 +137,13 @@ def create_connection(self): import psycopg2 except ModuleNotFoundError: raise DriverNotFound( - "You must have the 'psycopg2' package installed to make a connection to Postgres. Please install it using 'pip install psycopg2-binary'" + "You must have the 'psycopg2' package installed to make a " + "connection to Postgres. Please install it using " + "'pip install psycopg2-binary'" ) + connect_kwargs = self._build_connect_kwargs() + # Initialize the connection pool if the option is set initialize_size = self.full_details.get("connection_pooling_min_size") if ( @@ -77,47 +152,15 @@ def create_connection(self): and len(CONNECTION_POOL) < initialize_size ): for _ in range(initialize_size - len(CONNECTION_POOL)): - connection = psycopg2.connect( - database=self.database, - user=self.user, - password=self.password, - host=self.host, - port=self.port, - sslmode=self.options.get("sslmode"), - sslcert=self.options.get("sslcert"), - sslkey=self.options.get("sslkey"), - sslrootcert=self.options.get("sslrootcert"), - options=( - f"-c search_path={self.schema or self.full_details.get('schema')}" - if self.schema or self.full_details.get("schema") - else "" - ), - ) - CONNECTION_POOL.append(connection) + CONNECTION_POOL.append(psycopg2.connect(**connect_kwargs)) if ( self.full_details.get("connection_pooling_enabled") - and CONNECTION_POOL and len(CONNECTION_POOL) > 0 ): connection = CONNECTION_POOL.pop() else: - connection = psycopg2.connect( - database=self.database, - user=self.user, - password=self.password, - host=self.host, - port=self.port, - sslmode=self.options.get("sslmode"), - sslcert=self.options.get("sslcert"), - sslkey=self.options.get("sslkey"), - sslrootcert=self.options.get("sslrootcert"), - options=( - f"-c search_path={self.schema or self.full_details.get('schema')}" - if self.schema or self.full_details.get("schema") - else "" - ), - ) + connection = psycopg2.connect(**connect_kwargs) return connection @@ -222,4 +265,3 @@ def query(self, query, bindings=(), results="*"): if self.get_transaction_level() <= 0: self.open = 0 self.close_connection() - # self._connection.close() diff --git a/src/masoniteorm/connections/SQLiteConnection.py b/src/masoniteorm/connections/SQLiteConnection.py index 28aa007c..bd8bf97a 100644 --- a/src/masoniteorm/connections/SQLiteConnection.py +++ b/src/masoniteorm/connections/SQLiteConnection.py @@ -12,6 +12,20 @@ def regexp(expr, item): return reg.search(item) is not None +# Valid keyword arguments for sqlite3.connect() beyond the database path. +_SQLITE3_CONNECT_KWARGS = frozenset( + { + "timeout", + "detect_types", + "isolation_level", + "check_same_thread", + "factory", + "cached_statements", + "uri", + } +) + + class SQLiteConnection(BaseConnection): """SQLite Connection class.""" @@ -32,10 +46,7 @@ def __init__( name=None, ): self.host = host - if port: - self.port = int(port) - else: - self.port = port + self.port = int(port) if port else None self.database = database self.user = user self.password = password @@ -48,6 +59,29 @@ def __init__( if name: self.name = name + def _build_connect_kwargs(self): + """Return kwargs for sqlite3.connect() and any leftover PRAGMA settings. + + self.options is partitioned into two groups: + - Keys that are valid sqlite3.connect() keyword arguments are returned + in ``connect_kwargs`` and passed directly to the connect call. + - All remaining keys are returned in ``pragma_settings`` and applied + as ``PRAGMA key = value`` statements after the connection is opened. + + Returns: + tuple[dict, dict]: (connect_kwargs, pragma_settings) + """ + connect_kwargs = {} + pragma_settings = {} + + for key, value in self.options.items(): + if key in _SQLITE3_CONNECT_KWARGS: + connect_kwargs[key] = value + else: + pragma_settings[key] = value + + return connect_kwargs, pragma_settings + def make_connection(self): """This sets the connection on the connection class""" try: @@ -60,11 +94,20 @@ def make_connection(self): if self.has_global_connection(): return self.get_global_connection() - self._connection = sqlite3.connect(self.database, isolation_level=None) - self._connection.create_function("REGEXP", 2, regexp) + connect_kwargs, pragma_settings = self._build_connect_kwargs() + # isolation_level=None enables autocommit; only set it as a default if + # the caller hasn't provided their own isolation_level via options. + connect_kwargs.setdefault("isolation_level", None) + + self._connection = sqlite3.connect(self.database, **connect_kwargs) + self._connection.create_function("REGEXP", 2, regexp) self._connection.row_factory = sqlite3.Row + # Apply any leftover options as PRAGMA statements. + for key, value in pragma_settings.items(): + self._connection.execute(f"PRAGMA {key} = {value}") + self.enable_disable_foreign_keys() self.open = 1