
Commit 7fdd83e

Gabe Pesco and claude authored and committed
fix(databricks): use shared connection pool to prevent OAuth CSRF race

When concurrent_tasks > 1, DatabricksConnectionConfig previously used ThreadLocalConnectionPool, which creates a separate databricks.sql.connect() per thread. For U2M OAuth (databricks-oauth / azure-oauth), each thread triggers an independent browser-based OAuth flow; these race on the CSRF state parameter and cause MismatchingStateError.

Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be used instead: a single connection is created (behind a lock) and each thread receives its own cursor, so only one OAuth flow is ever initiated. This mirrors the existing pattern used by DuckDBConnectionConfig.

Fixes #5646

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Gabe Pesco <PescoG@medinsight.milliman.com>
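
To illustrate the difference described above, here is a minimal sketch of the two pooling strategies. It is not sqlmesh's implementation: `FakeConnection`, `PerThreadPoolSketch`, `SharedPoolSketch`, and `connections_opened` are hypothetical names invented for this example, and `FakeConnection` only counts how many times a connection is opened (standing in for how many OAuth flows would start).

```python
import threading


class FakeConnection:
    """Hypothetical stand-in for a driver connection; counts how many were opened."""

    opened = 0
    _count_lock = threading.Lock()

    def __init__(self):
        with FakeConnection._count_lock:
            FakeConnection.opened += 1

    def cursor(self):
        # Each call returns a distinct cursor-like object.
        return object()


class PerThreadPoolSketch:
    """Hypothetical pool: every thread lazily opens its own connection."""

    def __init__(self, connection_factory):
        self._factory = connection_factory
        self._local = threading.local()

    def get(self):
        if not hasattr(self._local, "connection"):
            self._local.connection = self._factory()
        return self._local.connection


class SharedPoolSketch:
    """Hypothetical pool: one connection created behind a lock, one cursor per thread."""

    def __init__(self, connection_factory):
        self._factory = connection_factory
        self._lock = threading.Lock()
        self._connection = None
        self._local = threading.local()

    def get(self):
        # The lock guarantees the factory runs at most once across all threads.
        with self._lock:
            if self._connection is None:
                self._connection = self._factory()
            return self._connection

    def get_cursor(self):
        # Cursors stay thread-local even though the connection is shared.
        if not hasattr(self._local, "cursor"):
            self._local.cursor = self.get().cursor()
        return self._local.cursor


def connections_opened(pool_cls, workers=4):
    """Run `workers` threads against a fresh pool and report how many connections were opened."""
    FakeConnection.opened = 0
    pool = pool_cls(FakeConnection)
    threads = [threading.Thread(target=pool.get) for _ in range(workers)]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return FakeConnection.opened
```

Under these assumptions, four worker threads open four connections with the per-thread pool and one with the shared pool, which is why only a single interactive login would be triggered in the shared case.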
1 parent 8f092ac commit 7fdd83e

2 files changed: 31 additions, 0 deletions


sqlmesh/core/config/connection.py

Lines changed: 2 additions & 0 deletions
@@ -813,6 +813,8 @@ class DatabricksConnectionConfig(ConnectionConfig):
     DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks"
     DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3
 
+    shared_connection: t.ClassVar[bool] = True
+
     _concurrent_tasks_validator = concurrent_tasks_validator
     _http_headers_validator = http_headers_validator
 
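
The change above is a single class-level flag. As a rough sketch of how such a flag could steer pool selection when several tasks run concurrently, consider the following. All names here (`BaseConfigSketch`, `DatabricksLikeConfigSketch`, `multithreaded_pool_class`, and the two placeholder pool classes) are hypothetical, and the actual selection logic in sqlmesh may differ.

```python
import typing as t
from dataclasses import dataclass


class PerThreadPoolSketch:
    """Placeholder for a pool that opens one connection per thread."""


class SharedPoolSketch:
    """Placeholder for a pool that shares one connection across threads."""


@dataclass
class BaseConfigSketch:
    concurrent_tasks: int = 1
    # Class-level default: configs do not share a connection unless they opt in.
    shared_connection: t.ClassVar[bool] = False


@dataclass
class DatabricksLikeConfigSketch(BaseConfigSketch):
    # Overriding the ClassVar is the whole opt-in, mirroring the two-line diff.
    shared_connection: t.ClassVar[bool] = True


def multithreaded_pool_class(config: BaseConfigSketch) -> type:
    """Assumed dispatch for the concurrent_tasks > 1 case only."""
    return SharedPoolSketch if config.shared_connection else PerThreadPoolSketch
```

Because the flag is a `ClassVar`, it is not a user-facing config field in this sketch; every instance of the subclass picks the shared pool without any change to user configuration.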

tests/core/test_connection_config.py

Lines changed: 29 additions & 0 deletions
@@ -1422,6 +1422,35 @@ def test_databricks(make_config):
     )
 
 
+def test_databricks_shared_connection(make_config):
+    """Databricks should use a shared connection pool to prevent OAuth CSRF races.
+
+    When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per
+    thread. For U2M OAuth, each thread triggers its own browser-based OAuth flow;
+    these race on the CSRF state parameter and cause MismatchingStateError.
+
+    Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be
+    used instead: a single connection is created (behind a lock) and each thread
+    gets its own cursor, so only one OAuth flow is ever initiated.
+
+    See: https://github.com/tobymao/sqlmesh/issues/5646
+    """
+    from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool
+
+    config = make_config(
+        type="databricks",
+        server_hostname="dbc-test.cloud.databricks.com",
+        http_path="sql/test/foo",
+        access_token="test-token",
+        concurrent_tasks=4,
+    )
+    assert isinstance(config, DatabricksConnectionConfig)
+    assert config.shared_connection is True
+
+    adapter = config.create_engine_adapter()
+    assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool)
+
+
 def test_engine_import_validator():
     with pytest.raises(
         ConfigError,
