diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 9c5efdc4..24e10d5f 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2724,6 +2724,10 @@ def bulkcopy( f"for auth_type '{self.connection._auth_type}': {e}" ) from e pycore_context["access_token"] = raw_token + # Token replaces credential fields — py-core's validator rejects + # access_token combined with authentication/user_name/password. + for key in ("authentication", "user_name", "password"): + pycore_context.pop(key, None) logger.debug( "Bulk copy: acquired fresh Azure AD token for auth_type=%s", self.connection._auth_type, diff --git a/tests/test_020_bulkcopy_auth_cleanup.py b/tests/test_020_bulkcopy_auth_cleanup.py new file mode 100644 index 00000000..9eb86d69 --- /dev/null +++ b/tests/test_020_bulkcopy_auth_cleanup.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Tests for bulkcopy auth field cleanup in cursor.py. + +When cursor.bulkcopy() acquires an Azure AD token, it must strip stale +authentication/user_name/password keys from the pycore_context dict before +passing it to mssql_py_core. The Rust validator rejects access_token +combined with those fields (ODBC parity). +""" + +import secrets +from unittest.mock import MagicMock, patch + +SAMPLE_TOKEN = secrets.token_hex(44) + + +def _make_cursor(connection_str, auth_type): + """Build a mock Cursor with just enough wiring for bulkcopy's auth path.""" + from mssql_python.cursor import Cursor + + mock_conn = MagicMock() + mock_conn.connection_str = connection_str + mock_conn._auth_type = auth_type + mock_conn._is_connected = True + + cursor = Cursor.__new__(Cursor) + cursor._connection = mock_conn + cursor.closed = False + cursor.hstmt = None + return cursor + + +class TestBulkcopyAuthCleanup: + """Verify cursor.bulkcopy strips stale auth fields after token acquisition.""" + + @patch("mssql_python.cursor.logger") + def test_token_replaces_auth_fields(self, mock_logger): + """access_token present ⇒ authentication, user_name, password removed.""" + mock_logger.is_debug_enabled = False + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" + "Authentication=ActiveDirectoryDefault;UID=user@test.com;PWD=secret", + "activedirectorydefault", + ) + + captured_context = {} + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + + def capture_context(ctx, **kwargs): + captured_context.update(ctx) + return mock_pycore_conn + + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = capture_context + + with ( + patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}), + patch("mssql_python.auth.AADAuth.get_raw_token", return_value=SAMPLE_TOKEN), + ): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert captured_context.get("access_token") == SAMPLE_TOKEN + assert "authentication" not in captured_context + assert "user_name" not in captured_context + assert "password" not in captured_context + + @patch("mssql_python.cursor.logger") + def test_no_auth_type_leaves_fields_intact(self, mock_logger): + """No _auth_type ⇒ credentials pass through unchanged (SQL auth path).""" + mock_logger.is_debug_enabled = False + + cursor = _make_cursor( + "Server=tcp:test.database.windows.net;Database=testdb;" "UID=sa;PWD=password123", + None, # no AD auth + ) + + captured_context = {} + + mock_pycore_cursor = MagicMock() + mock_pycore_cursor.bulkcopy.return_value = { + "rows_copied": 1, + "batch_count": 1, + "elapsed_time": 0.1, + } + mock_pycore_conn = MagicMock() + mock_pycore_conn.cursor.return_value = mock_pycore_cursor + + def capture_context(ctx, **kwargs): + captured_context.update(ctx) + return mock_pycore_conn + + mock_pycore_module = MagicMock() + mock_pycore_module.PyCoreConnection = capture_context + + with patch.dict("sys.modules", {"mssql_py_core": mock_pycore_module}): + cursor.bulkcopy("dbo.test_table", [(1, "row")], timeout=10) + + assert "access_token" not in captured_context + assert captured_context.get("user_name") == "sa" + assert captured_context.get("password") == "password123"