Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions src/masoniteorm/connections/MSSQLConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@

CONNECTION_POOL = []

# Options that are handled explicitly when building the connection string.
# Anything in self.options that is NOT in this set will be appended verbatim
# as additional "Key=Value" pairs in the pyodbc connection string.
_MSSQL_KNOWN_OPTIONS = frozenset(
{
"driver",
"integrated_security",
"connection_timeout",
"authentication",
"instance",
"trusted_connection",
}
)


class MSSQLConnection(BaseConnection):
"""MSSQL Connection class."""
Expand All @@ -25,10 +39,7 @@ def __init__(
name=None,
):
self.host = host
if port:
self.port = int(port)
else:
self.port = port
self.port = int(port) if port else None
self.database = database
self.user = user
self.password = password
Expand All @@ -41,6 +52,50 @@ def __init__(
if name:
self.name = name

def _build_connection_string(self):
"""Build the pyodbc connection string from self.options.

Known options are mapped to their canonical ODBC connection string
keys. Any remaining entries in self.options that are not in
``_MSSQL_KNOWN_OPTIONS`` are appended verbatim as ``Key=Value`` pairs,
allowing arbitrary ODBC attributes to be passed through.

Returns:
str: A semicolon-delimited ODBC connection string.
"""
driver = self.options.get("driver", "ODBC Driver 17 for SQL Server")
connection_timeout = str(self.options.get("connection_timeout", "30"))
integrated_security = self.options.get("integrated_security")
trusted_connection = self.options.get("trusted_connection")
authentication = self.options.get("authentication")
instance = self.options.get("instance", "")

if instance:
instance = "\\" + instance

parts = [
f"DRIVER={driver}",
f"SERVER={self.host}{instance},{self.port}",
f"Connection Timeout={connection_timeout}",
f"DATABASE={self.database}",
f"UID={self.user}",
f"PWD={self.password}",
]

if integrated_security:
parts.append(f"Integrated Security={integrated_security}")
if trusted_connection:
parts.append(f"Trusted_Connection={trusted_connection}")
if authentication:
parts.append(f"Authentication={authentication}")

# Append any extra options not handled above.
for key, value in self.options.items():
if key not in _MSSQL_KNOWN_OPTIONS:
parts.append(f"{key}={value}")

return ";".join(parts)

def make_connection(self):
"""This sets the connection on the connection class"""
try:
Expand All @@ -53,18 +108,8 @@ def make_connection(self):
if self.has_global_connection():
return self.get_global_connection()

driver = self.options.get("driver", "ODBC Driver 17 for SQL Server")
integrated_security = self.options.get("integrated_security")
connection_timeout = str(self.options.get("connection_timeout", "30"))
authentication = self.options.get("authentication")
instance = self.options.get("instance", "")
trusted_connection = self.options.get("trusted_connection")

if instance:
instance = "\\" + instance

self._connection = pyodbc.connect(
f"DRIVER={driver};SERVER={self.host}{instance if instance else ''},{self.port};Connection Timeout={connection_timeout};DATABASE={self.database}{f';Integrated Security={integrated_security}' if integrated_security else ''};UID={self.user};PWD={self.password}{f';Trusted_Connection={trusted_connection}' if trusted_connection else ''}{f';Authentication={authentication}' if authentication else ''}",
self._build_connection_string(),
autocommit=True,
)

Expand Down Expand Up @@ -134,7 +179,6 @@ def query(self, query, bindings=(), results="*"):
Returns:
dict|None -- Returns a dictionary of results or None
"""

try:
if not self.open:
self.make_connection()
Expand Down
4 changes: 1 addition & 3 deletions src/masoniteorm/connections/MySQLConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def __init__(
name=None,
):
self.host = host
self.port = port
if str(port).isdigit():
self.port = int(self.port)
self.port = int(port) if port else None
self.database = database

self.user = user
Expand Down
122 changes: 82 additions & 40 deletions src/masoniteorm/connections/PostgresConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,39 @@

CONNECTION_POOL = []

# All keyword arguments accepted directly by psycopg2.connect(), excluding
# 'options' which we build ourselves from leftover entries in self.options.
_PSYCOPG2_CONNECT_KWARGS = frozenset(
{
# Core identity
"database",
"dbname",
"user",
"password",
"host",
"port",
# SSL
"sslmode",
"sslcert",
"sslkey",
"sslrootcert",
"sslcrl",
"sslpassword",
# Timeouts & keepalives
"connect_timeout",
"keepalives",
"keepalives_idle",
"keepalives_interval",
"keepalives_count",
# Application identification
"application_name",
"fallback_application_name",
# Misc
"cursor_factory",
"async",
}
)


class PostgresConnection(BaseConnection):
"""Postgres Connection class."""
Expand All @@ -25,10 +58,7 @@ def __init__(
name=None,
):
self.host = host
if port:
self.port = int(port)
else:
self.port = port
self.port = int(port) if port else None
self.database = database
self.user = user
self.password = password
Expand All @@ -46,6 +76,47 @@ def __init__(
if name:
self.name = name

def _build_connect_kwargs(self):
"""Return a dict of kwargs ready to be unpacked into psycopg2.connect().

self.options is partitioned into two groups:
- Keys that are valid psycopg2.connect() keyword arguments are passed
through directly.
- All remaining keys are treated as PostgreSQL GUC parameters and
appended to the ``options`` string as ``-c key=value`` entries
alongside the schema search_path (when set).
"""
# --- Base connection parameters ---------------------------------
kwargs = {
"database": self.database,
"user": self.user,
"password": self.password,
"host": self.host,
"port": self.port,
}

# --- Split self.options into direct kwargs vs GUC params --------
guc_params = {}
for key, value in self.options.items():
if key in _PSYCOPG2_CONNECT_KWARGS:
kwargs[key] = value
else:
guc_params[key] = value

# --- Build the options / GUC string -----------------------------
# search_path comes from the instance schema or full_details, but can
# also be overridden via self.options by passing it as a leftover key.
schema = self.schema or self.full_details.get("schema")
if schema and "search_path" not in guc_params:
guc_params["search_path"] = schema

if guc_params:
kwargs["options"] = " ".join(
f"-c {k}={v}" for k, v in guc_params.items()
)

return kwargs

def make_connection(self):
"""This sets the connection on the connection class"""
if self.has_global_connection():
Expand All @@ -66,9 +137,13 @@ def create_connection(self):
import psycopg2
except ModuleNotFoundError:
raise DriverNotFound(
"You must have the 'psycopg2' package installed to make a connection to Postgres. Please install it using 'pip install psycopg2-binary'"
"You must have the 'psycopg2' package installed to make a "
"connection to Postgres. Please install it using "
"'pip install psycopg2-binary'"
)

connect_kwargs = self._build_connect_kwargs()

# Initialize the connection pool if the option is set
initialize_size = self.full_details.get("connection_pooling_min_size")
if (
Expand All @@ -77,47 +152,15 @@ def create_connection(self):
and len(CONNECTION_POOL) < initialize_size
):
for _ in range(initialize_size - len(CONNECTION_POOL)):
connection = psycopg2.connect(
database=self.database,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
sslmode=self.options.get("sslmode"),
sslcert=self.options.get("sslcert"),
sslkey=self.options.get("sslkey"),
sslrootcert=self.options.get("sslrootcert"),
options=(
f"-c search_path={self.schema or self.full_details.get('schema')}"
if self.schema or self.full_details.get("schema")
else ""
),
)
CONNECTION_POOL.append(connection)
CONNECTION_POOL.append(psycopg2.connect(**connect_kwargs))

if (
self.full_details.get("connection_pooling_enabled")
and CONNECTION_POOL
and len(CONNECTION_POOL) > 0
):
connection = CONNECTION_POOL.pop()
else:
connection = psycopg2.connect(
database=self.database,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
sslmode=self.options.get("sslmode"),
sslcert=self.options.get("sslcert"),
sslkey=self.options.get("sslkey"),
sslrootcert=self.options.get("sslrootcert"),
options=(
f"-c search_path={self.schema or self.full_details.get('schema')}"
if self.schema or self.full_details.get("schema")
else ""
),
)
connection = psycopg2.connect(**connect_kwargs)

return connection

Expand Down Expand Up @@ -222,4 +265,3 @@ def query(self, query, bindings=(), results="*"):
if self.get_transaction_level() <= 0:
self.open = 0
self.close_connection()
# self._connection.close()
55 changes: 49 additions & 6 deletions src/masoniteorm/connections/SQLiteConnection.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ def regexp(expr, item):
return reg.search(item) is not None


# Valid keyword arguments for sqlite3.connect() beyond the database path.
_SQLITE3_CONNECT_KWARGS = frozenset(
{
"timeout",
"detect_types",
"isolation_level",
"check_same_thread",
"factory",
"cached_statements",
"uri",
}
)


class SQLiteConnection(BaseConnection):
"""SQLite Connection class."""

Expand All @@ -32,10 +46,7 @@ def __init__(
name=None,
):
self.host = host
if port:
self.port = int(port)
else:
self.port = port
self.port = int(port) if port else None
self.database = database
self.user = user
self.password = password
Expand All @@ -48,6 +59,29 @@ def __init__(
if name:
self.name = name

def _build_connect_kwargs(self):
"""Return kwargs for sqlite3.connect() and any leftover PRAGMA settings.

self.options is partitioned into two groups:
- Keys that are valid sqlite3.connect() keyword arguments are returned
in ``connect_kwargs`` and passed directly to the connect call.
- All remaining keys are returned in ``pragma_settings`` and applied
as ``PRAGMA key = value`` statements after the connection is opened.

Returns:
tuple[dict, dict]: (connect_kwargs, pragma_settings)
"""
connect_kwargs = {}
pragma_settings = {}

for key, value in self.options.items():
if key in _SQLITE3_CONNECT_KWARGS:
connect_kwargs[key] = value
else:
pragma_settings[key] = value

return connect_kwargs, pragma_settings

def make_connection(self):
"""This sets the connection on the connection class"""
try:
Expand All @@ -60,11 +94,20 @@ def make_connection(self):
if self.has_global_connection():
return self.get_global_connection()

self._connection = sqlite3.connect(self.database, isolation_level=None)
self._connection.create_function("REGEXP", 2, regexp)
connect_kwargs, pragma_settings = self._build_connect_kwargs()

# isolation_level=None enables autocommit; only set it as a default if
# the caller hasn't provided their own isolation_level via options.
connect_kwargs.setdefault("isolation_level", None)

self._connection = sqlite3.connect(self.database, **connect_kwargs)
self._connection.create_function("REGEXP", 2, regexp)
self._connection.row_factory = sqlite3.Row

# Apply any leftover options as PRAGMA statements.
for key, value in pragma_settings.items():
self._connection.execute(f"PRAGMA {key} = {value}")

self.enable_disable_foreign_keys()

self.open = 1
Expand Down
Loading