From 8ac1f7db09dc207600af723d94ad417805e72d95 Mon Sep 17 00:00:00 2001 From: Davide Mauri Date: Wed, 22 Oct 2025 06:02:00 +0000 Subject: [PATCH] Checkpoint from VS Code for coding agent session --- vectordb_bench/backend/clients/mssql/mssql.py | 79 ++++++++++++------- 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/vectordb_bench/backend/clients/mssql/mssql.py b/vectordb_bench/backend/clients/mssql/mssql.py index 829872a63..4be85014a 100644 --- a/vectordb_bench/backend/clients/mssql/mssql.py +++ b/vectordb_bench/backend/clients/mssql/mssql.py @@ -15,12 +15,20 @@ import struct import azure.identity +import multiprocessing as mp + log = logging.getLogger(__name__) # --- Constants for Token Authentication --- SQL_COPT_SS_ACCESS_TOKEN = 1256 SQL_SERVER_TOKEN_SCOPE = "https://database.windows.net/.default" +# --- Shared variables for token refresh across processes --- +# Initialize manager and shared resources +_manager = mp.Manager() +token_refresh_lock = _manager.Lock() +shared_token_data = _manager.dict({'token': None, 'expires_on': None}) + class MSSQL(VectorDB): supported_filter_types: list[FilterOp] = [ @@ -44,7 +52,6 @@ def __init__( self.dim = dim self.schema_name = "benchmark" self.drop_old = drop_old - self.access_token = None log.info("db_case_config: " + str(db_case_config)) @@ -69,7 +76,7 @@ def __init__( cnxn.commit() - log.info(f"Creating vector table '[{self.schema_name}].[{self.table_name}]'...") + log.info(f"Creating vector table '[{self.schema_name}].[{self.table_name}]' if not existing...") cursor.execute(f""" if object_id('[{self.schema_name}].[{self.table_name}]') is null begin create table [{self.schema_name}].[{self.table_name}] ( @@ -155,7 +162,6 @@ def ready_to_search(self): log.info(f"MSSQL ready to search") pass - def prepare_filter(self, filters: Filter): log.info(f"Preparing filters: {filters}") if filters.type == FilterOp.NonFilter: @@ -237,36 +243,55 @@ def connect(self): # --- Case 2: Entra ID Managed Identity (Manual Token Auth) --- - # check if token exists and if it expires within the next hour (or is already expired) - if self.access_token is not None: - remaining_seconds = self.access_token.expires_on - time.time() - if remaining_seconds < 300: # expires within 5 minutes or already expired - expiration_datetime = datetime.fromtimestamp(self.access_token.expires_on) - if remaining_seconds <= 0: - log.info(f"Token expired on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} ({abs(remaining_seconds):.0f} seconds ago), acquiring a new one.") - else: - log.info(f"Token expires on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} (in {remaining_seconds:.0f} seconds), acquiring a new one.") - self.access_token = None - - if self.access_token is None: - log.info(f"Acquiring token for authentication...") + # Protect token refresh with lock to ensure only one process accesses it at a time + log.info(f"Process {mp.current_process().name} attempting to acquire token refresh lock...") + with token_refresh_lock: + log.info(f"Process {mp.current_process().name} acquired token refresh lock") + token_str = shared_token_data['token'] + expires_on = shared_token_data['expires_on'] + + # if expires within 5 minutes or already expired then force a refresh + if token_str is not None and expires_on is not None: + remaining_seconds = expires_on - time.time() + if remaining_seconds < 300: + expiration_datetime = datetime.fromtimestamp(expires_on) + if remaining_seconds <= 0: + log.info(f"Token expired on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} ({abs(remaining_seconds):.0f} seconds ago), acquiring a new one.") + else: + log.info(f"Token expires on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} (in {remaining_seconds:.0f} seconds), acquiring a new one.") + shared_token_data['token'] = None + shared_token_data['expires_on'] = None + token_str = None - if authentication == "AzureCLICredential": - log.info("Using Azure CLI Credentials for authentication.") - credential = azure.identity.AzureCliCredential() + # refresh token if not available or expired + if token_str is None: + log.info(f"Acquiring token for authentication...") - if authentication == "ManagedIdentityCredential": - log.info(f"Using Managed Identity Credential with client_id: {self.db_config.get('principal')}") - credential = azure.identity.ManagedIdentityCredential(client_id=self.db_config.get("principal")) - - access_token = credential.get_token(SQL_SERVER_TOKEN_SCOPE) - log.info("Token acquired successfully.") + if authentication == "AzureCLICredential": + log.info("Using Azure CLI Credentials for authentication.") + credential = azure.identity.AzureCliCredential() + + if authentication == "ManagedIdentityCredential": + log.info(f"Using Managed Identity Credential with client_id: {self.db_config.get('principal')}") + credential = azure.identity.ManagedIdentityCredential(client_id=self.db_config.get("principal")) + + access_token = credential.get_token(SQL_SERVER_TOKEN_SCOPE) + log.info("Token acquired successfully.") - self.access_token = access_token + # Store token as string in shared variable + shared_token_data['token'] = access_token.token + shared_token_data['expires_on'] = access_token.expires_on + token_str = access_token.token + log.info(f"Process {mp.current_process().name} stored new token in shared data") + else: + log.info(f"Process {mp.current_process().name} using existing token from shared data") - token_bytes = self.access_token.token.encode("UTF-16-LE") + log.info(f"Process {mp.current_process().name} released token refresh lock") + # Prepare token structure for pyodbc + token_bytes = token_str.encode("UTF-16-LE") token_struct = struct.pack(f'