Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 52 additions & 27 deletions vectordb_bench/backend/clients/mssql/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@
import struct
import azure.identity

import multiprocessing as mp

log = logging.getLogger(__name__)

# --- Constants for Token Authentication ---
SQL_COPT_SS_ACCESS_TOKEN = 1256
SQL_SERVER_TOKEN_SCOPE = "https://database.windows.net/.default"

# --- Shared variables for token refresh across processes ---
# Initialize manager and shared resources
_manager = mp.Manager()
token_refresh_lock = _manager.Lock()
shared_token_data = _manager.dict({'token': None, 'expires_on': None})

class MSSQL(VectorDB):

supported_filter_types: list[FilterOp] = [
Expand All @@ -44,7 +52,6 @@ def __init__(
self.dim = dim
self.schema_name = "benchmark"
self.drop_old = drop_old
self.access_token = None

log.info("db_case_config: " + str(db_case_config))

Expand All @@ -69,7 +76,7 @@ def __init__(
cnxn.commit()


log.info(f"Creating vector table '[{self.schema_name}].[{self.table_name}]'...")
log.info(f"Creating vector table '[{self.schema_name}].[{self.table_name}]' if not existing...")
cursor.execute(f"""
if object_id('[{self.schema_name}].[{self.table_name}]') is null begin
create table [{self.schema_name}].[{self.table_name}] (
Expand Down Expand Up @@ -155,7 +162,6 @@ def ready_to_search(self):
log.info(f"MSSQL ready to search")
pass


def prepare_filter(self, filters: Filter):
log.info(f"Preparing filters: {filters}")
if filters.type == FilterOp.NonFilter:
Expand Down Expand Up @@ -237,36 +243,55 @@ def connect(self):

# --- Case 2: Entra ID Managed Identity (Manual Token Auth) ---

# check if token exists and if it expires within the next hour (or is already expired)
if self.access_token is not None:
remaining_seconds = self.access_token.expires_on - time.time()
if remaining_seconds < 300: # expires within 5 minutes or already expired
expiration_datetime = datetime.fromtimestamp(self.access_token.expires_on)
if remaining_seconds <= 0:
log.info(f"Token expired on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} ({abs(remaining_seconds):.0f} seconds ago), acquiring a new one.")
else:
log.info(f"Token expires on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} (in {remaining_seconds:.0f} seconds), acquiring a new one.")
self.access_token = None

if self.access_token is None:
log.info(f"Acquiring token for authentication...")
# Protect token refresh with lock to ensure only one process accesses it at a time
log.info(f"Process {mp.current_process().name} attempting to acquire token refresh lock...")
with token_refresh_lock:
log.info(f"Process {mp.current_process().name} acquired token refresh lock")
token_str = shared_token_data['token']
expires_on = shared_token_data['expires_on']

# if expires within 5 minutes or already expired then force a refresh
if token_str is not None and expires_on is not None:
remaining_seconds = expires_on - time.time()
if remaining_seconds < 300:
expiration_datetime = datetime.fromtimestamp(expires_on)
if remaining_seconds <= 0:
log.info(f"Token expired on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} ({abs(remaining_seconds):.0f} seconds ago), acquiring a new one.")
else:
log.info(f"Token expires on {expiration_datetime.strftime('%Y-%m-%d %H:%M:%S')} (in {remaining_seconds:.0f} seconds), acquiring a new one.")
shared_token_data['token'] = None
shared_token_data['expires_on'] = None
token_str = None

if authentication == "AzureCLICredential":
log.info("Using Azure CLI Credentials for authentication.")
credential = azure.identity.AzureCliCredential()
# refresh token if not available or expired
if token_str is None:
log.info(f"Acquiring token for authentication...")

if authentication == "ManagedIdentityCredential":
log.info(f"Using Managed Identity Credential with client_id: {self.db_config.get('principal')}")
credential = azure.identity.ManagedIdentityCredential(client_id=self.db_config.get("principal"))

access_token = credential.get_token(SQL_SERVER_TOKEN_SCOPE)
log.info("Token acquired successfully.")
if authentication == "AzureCLICredential":
log.info("Using Azure CLI Credentials for authentication.")
credential = azure.identity.AzureCliCredential()

if authentication == "ManagedIdentityCredential":
log.info(f"Using Managed Identity Credential with client_id: {self.db_config.get('principal')}")
credential = azure.identity.ManagedIdentityCredential(client_id=self.db_config.get("principal"))

access_token = credential.get_token(SQL_SERVER_TOKEN_SCOPE)
log.info("Token acquired successfully.")

self.access_token = access_token
# Store token as string in shared variable
shared_token_data['token'] = access_token.token
shared_token_data['expires_on'] = access_token.expires_on
token_str = access_token.token
log.info(f"Process {mp.current_process().name} stored new token in shared data")
else:
log.info(f"Process {mp.current_process().name} using existing token from shared data")

token_bytes = self.access_token.token.encode("UTF-16-LE")
log.info(f"Process {mp.current_process().name} released token refresh lock")
# Prepare token structure for pyodbc
token_bytes = token_str.encode("UTF-16-LE")
token_struct = struct.pack(f'<I{len(token_bytes)}s', len(token_bytes), token_bytes)

# Create a new connection with the access token
cnxn = pyodbc.connect(
self.db_config.get("connection_string"),
attrs_before={SQL_COPT_SS_ACCESS_TOKEN: token_struct}
Expand Down
Loading