diff --git a/src/ansys/hps/client/client.py b/src/ansys/hps/client/client.py index 4e9351ed1..8314df22c 100644 --- a/src/ansys/hps/client/client.py +++ b/src/ansys/hps/client/client.py @@ -27,7 +27,9 @@ import os import platform import tempfile +import threading import warnings +from datetime import datetime, timedelta, timezone import jwt import requests @@ -94,6 +96,18 @@ class Client: disable_security_warnings : bool, optional Whether to disable urllib3 warnings about insecure HTTPS requests. The default is ``True``. For more information, see urllib3 documentation about TLS warnings. + auto_refresh_token : bool, optional + Whether to automatically refresh access token before it expires. The default is ``True``. + token_refresh_factor : float, optional + Fraction of the token lifetime at which the first preemptive refresh is + scheduled. Must be in the open interval ``(0, 1)``. The default is ``0.70``. + token_refresh_retry_factors : sequence of float, optional + Strictly increasing fractions in ``(token_refresh_factor, 1)`` used to + reschedule the refresh after a failed attempt. The default is + ``(0.80, 0.90, 0.95, 0.98)``. + token_refresh_loop_interval : float, optional + Maximum interval, in seconds, between checks of the background token refresh + loop. The default is ``300``. Examples -------- @@ -132,6 +146,10 @@ def __init__( all_fields=True, verify: bool | str = None, disable_security_warnings: bool = True, + auto_refresh_token: bool = True, + token_refresh_factor: float = 0.70, + token_refresh_retry_factors: tuple[float, ...] = (0.80, 0.90, 0.95, 0.98), + token_refresh_loop_interval: float = 300, **kwargs, ): """Initialize the Client object.""" @@ -162,6 +180,28 @@ def __init__( self.client_secret = client_secret self.verify = verify self.data_transfer_url = url + "/dt/api/v1" + self._token_refresh_thread = None + self._stop_event = threading.Event() + if not 0 < token_refresh_factor < 1: + raise ValueError("token_refresh_factor must be in the open interval (0, 1).") + if token_refresh_loop_interval <= 0: + raise ValueError("token_refresh_loop_interval must be positive.") + retry_factors = tuple(token_refresh_retry_factors) + prev = token_refresh_factor + for f in retry_factors: + if not prev < f < 1: + raise ValueError( + "token_refresh_retry_factors must be strictly increasing and " + "in (token_refresh_factor, 1)." + ) + prev = f + self.token_refresh_factor = token_refresh_factor + self.token_refresh_retry_factors = retry_factors + self.loop_interval = token_refresh_loop_interval + self._refresh_attempt = 0 + self.token_expires_in = None + self.token_acquired_date = None + self.token_refresh_date = None self._dt_client: DataTransferClient | None = None self._dt_api: DataTransferApi | None = None @@ -214,6 +254,8 @@ def __init__( # client credentials flow does not return a refresh token self.refresh_token = tokens.get("refresh_token", None) + self._update_token_expiry(tokens) + parsed_username = None token = {} try: @@ -244,8 +286,13 @@ def __init__( self.session.hooks["response"] = [self._auto_refresh_token, raise_for_status] self._unauthorized_num_retry = 0 self._unauthorized_max_retry = 1 + if auto_refresh_token and self.token_refresh_date is not None: + self._start_token_refresh_thread() def exit_handler(): + self._stop_event.set() + if self._token_refresh_thread is not None: + self._token_refresh_thread.join(timeout=5) if self._dt_client is not None: log.info("Stopping the data transfer client gracefully.") self._dt_client.stop() @@ -344,6 +391,103 @@ def auth_api_url(self) -> str: log.error("auth_api not valid for non-keycloak implementation") return None + def _start_token_refresh_thread(self): + """Start a background thread to refresh the access token.""" + if self._token_refresh_thread is not None and self._token_refresh_thread.is_alive(): + return + + self._token_refresh_thread = threading.Thread( + target=self._periodically_refresh_token, + name="periodic_token_refresh", + ) + self._token_refresh_thread.daemon = True + self._token_refresh_thread.start() + + def _update_token_expiry(self, tokens): + """Update expiry-related fields from a token response.""" + expires_in = [] + access_expires_in = tokens.get("expires_in", None) + if access_expires_in is not None: + log.debug(f"Access token expires in {timedelta(seconds=int(access_expires_in))}") + expires_in.append(access_expires_in) + refresh_expires_in = tokens.get("refresh_expires_in", None) + if refresh_expires_in is not None: + info = ( + "offline" + if refresh_expires_in == 0 and "offline_access" in self.scope + else f"expires in {timedelta(seconds=int(refresh_expires_in))}" + ) + log.debug(f"Refresh token {info}") + if refresh_expires_in > 0: + expires_in.append(refresh_expires_in) + self.token_expires_in = min(expires_in) if expires_in else None + if self.token_expires_in is not None: + log.debug(f"Setting token expiry to {timedelta(seconds=int(self.token_expires_in))}") + self.token_acquired_date = ( + datetime.now(timezone.utc) if self.token_expires_in is not None else None + ) + + self._refresh_attempt = 0 + if self.token_expires_in is not None: + offset = max(1, int(self.token_expires_in * self.token_refresh_factor)) + self.token_refresh_date = self.token_acquired_date + timedelta(seconds=offset) + log.debug( + "Refresh token set, auto refresh in " + f"{timedelta(seconds=offset)} ({self.token_refresh_date})" + ) + else: + self.token_refresh_date = None + + def _periodically_refresh_token(self): + """Periodically check if the token needs to be refreshed and refresh it.""" + while not self._stop_event.is_set(): + if self.token_refresh_date is None: + if self._stop_event.wait(self.loop_interval): + break + continue + + now = datetime.now(timezone.utc) + if now > self.token_refresh_date: + log.debug("Attempting preemptive authentication token refresh") + try: + self.refresh_access_token() + except Exception as ex: + self._reschedule_after_failed_refresh(ex) + continue + + diff = self.token_refresh_date - now + sleep_time = max(0.1, min(self.loop_interval, diff.total_seconds())) + if self._stop_event.wait(sleep_time): + break + log.debug("Token refresh thread stopped") + + def _reschedule_after_failed_refresh(self, ex): + """Schedule the next refresh attempt after a failure, if any retries remain.""" + self._refresh_attempt += 1 + if ( + self._refresh_attempt > len(self.token_refresh_retry_factors) + or self.token_acquired_date is None + or self.token_expires_in is None + ): + log.error( + "Preemptive token refresh failed and no retries remain: %s. " + "Falling back to on-demand refresh on the next 401 response.", + ex, + ) + self.token_refresh_date = None + return + + factor = self.token_refresh_retry_factors[self._refresh_attempt - 1] + offset = max(1, int(self.token_expires_in * factor)) + self.token_refresh_date = self.token_acquired_date + timedelta(seconds=offset) + log.warning( + "Preemptive token refresh failed (%s); next attempt scheduled at %.0f%% " + "of token lifetime (%s).", + ex, + factor * 100, + self.token_refresh_date, + ) + def _auto_refresh_token(self, response, *args, **kwargs): """Provide a callback for refreshing an expired token. @@ -396,6 +540,7 @@ def refresh_access_token(self): self.access_token = tokens["access_token"] self.refresh_token = tokens.get("refresh_token", None) self.session.headers.update({"Authorization": f"Bearer {tokens['access_token']}"}) + self._update_token_expiry(tokens) @property def data_transfer_client(self) -> DataTransferClient: diff --git a/tests/test_client.py b/tests/test_client.py index a888b6106..26229d683 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -21,7 +21,9 @@ # SOFTWARE.get_projects import logging +import threading import time +from datetime import datetime, timedelta, timezone import pytest import requests @@ -125,3 +127,80 @@ def test_dt_client(url, username, password): assert client.data_transfer_client == client._dt_client assert client.data_transfer_api == client._dt_api + + +def test_update_token_expiry_sets_refresh_date(url, username, password): + """After authentication, expiry-related fields must be populated.""" + client = Client(url, username, password) + + assert client.token_expires_in is not None + assert client.token_acquired_date is not None + assert client.token_refresh_date is not None + # refresh date should be in the future and within token lifetime + assert client.token_refresh_date > client.token_acquired_date + diff = (client.token_refresh_date - client.token_acquired_date).total_seconds() + assert 0 < diff <= client.token_expires_in + + +def test_update_token_expiry_updates_after_refresh(url, username, password): + """Calling refresh_access_token must move token_refresh_date forward.""" + client = Client(url, username, password) + first_refresh_date = client.token_refresh_date + + time.sleep(0.5) + client.refresh_access_token() + + assert client.token_refresh_date > first_refresh_date + + +def test_reschedule_after_failed_refresh(url, username, password): + """Failed refreshes must escalate through retry factors, then give up.""" + client = Client(url, username, password) + + # Stop the background thread so it doesn't race with our manipulations. + client._stop_event.set() + client._token_refresh_thread.join(timeout=5) + + acquired = client.token_acquired_date + expires_in = client.token_expires_in + retry_factors = client.token_refresh_retry_factors + assert len(retry_factors) > 0 + + err = RuntimeError("simulated refresh failure") + + # Each failure should reschedule at the next retry factor. + for i, factor in enumerate(retry_factors, start=1): + client._reschedule_after_failed_refresh(err) + assert client._refresh_attempt == i + expected_offset = max(1, int(expires_in * factor)) + expected_date = acquired + timedelta(seconds=expected_offset) + assert client.token_refresh_date == expected_date + + # One more failure exhausts the retries and disables preemptive refresh. + client._reschedule_after_failed_refresh(err) + assert client.token_refresh_date is None + + +def test_periodically_refresh_token_refreshes_preemptively(url, username, password): + """The background thread must refresh the access token before it expires.""" + client = Client(url, username, password) + initial_access_token = client.access_token + + # The background thread is already sleeping with the default loop_interval + # against a fresh token_refresh_date. Stop it, retune the schedule so a + # refresh is due immediately, then restart so the new values take effect. + client._stop_event.set() + client._token_refresh_thread.join(timeout=5) + client._stop_event = threading.Event() + client._token_refresh_thread = None + client.loop_interval = 0.1 + client.token_refresh_date = datetime.now(timezone.utc) - timedelta(seconds=1) + client._start_token_refresh_thread() + + # Wait for the background thread to perform the refresh + deadline = time.time() + 5 + while time.time() < deadline and client.access_token == initial_access_token: + time.sleep(0.1) + + assert client.access_token != initial_access_token + assert client.token_refresh_date > datetime.now(timezone.utc)