Skip to content
Merged
145 changes: 145 additions & 0 deletions src/ansys/hps/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import os
import platform
import tempfile
import threading
import warnings
from datetime import datetime, timedelta, timezone

import jwt
import requests
Expand Down Expand Up @@ -94,6 +96,18 @@ class Client:
disable_security_warnings : bool, optional
Whether to disable urllib3 warnings about insecure HTTPS requests. The default is ``True``.
For more information, see urllib3 documentation about TLS warnings.
auto_refresh_token : bool, optional
Whether to automatically refresh access token before it expires. The default is ``True``.
token_refresh_factor : float, optional
Fraction of the token lifetime at which the first preemptive refresh is
scheduled. Must be in the open interval ``(0, 1)``. The default is ``0.70``.
token_refresh_retry_factors : sequence of float, optional
Strictly increasing fractions in ``(token_refresh_factor, 1)`` used to
reschedule the refresh after a failed attempt. The default is
``(0.80, 0.90, 0.95, 0.98)``.
token_refresh_loop_interval : float, optional
Maximum interval, in seconds, between checks of the background token refresh
loop. The default is ``300``.

Examples
--------
Expand Down Expand Up @@ -132,6 +146,10 @@ def __init__(
all_fields=True,
verify: bool | str = None,
disable_security_warnings: bool = True,
auto_refresh_token: bool = True,
token_refresh_factor: float = 0.70,
token_refresh_retry_factors: tuple[float, ...] = (0.80, 0.90, 0.95, 0.98),
token_refresh_loop_interval: float = 300,
**kwargs,
):
"""Initialize the Client object."""
Expand Down Expand Up @@ -162,6 +180,28 @@ def __init__(
self.client_secret = client_secret
self.verify = verify
self.data_transfer_url = url + "/dt/api/v1"
self._token_refresh_thread = None
self._stop_event = threading.Event()
if not 0 < token_refresh_factor < 1:
raise ValueError("token_refresh_factor must be in the open interval (0, 1).")
if token_refresh_loop_interval <= 0:
raise ValueError("token_refresh_loop_interval must be positive.")
retry_factors = tuple(token_refresh_retry_factors)
prev = token_refresh_factor
for f in retry_factors:
if not prev < f < 1:
raise ValueError(
"token_refresh_retry_factors must be strictly increasing and "
"in (token_refresh_factor, 1)."
)
prev = f
self.token_refresh_factor = token_refresh_factor
self.token_refresh_retry_factors = retry_factors
self.loop_interval = token_refresh_loop_interval
self._refresh_attempt = 0
self.token_expires_in = None
self.token_acquired_date = None
self.token_refresh_date = None

self._dt_client: DataTransferClient | None = None
self._dt_api: DataTransferApi | None = None
Expand Down Expand Up @@ -214,6 +254,8 @@ def __init__(
# client credentials flow does not return a refresh token
self.refresh_token = tokens.get("refresh_token", None)

self._update_token_expiry(tokens)

parsed_username = None
token = {}
try:
Expand Down Expand Up @@ -244,8 +286,13 @@ def __init__(
self.session.hooks["response"] = [self._auto_refresh_token, raise_for_status]
self._unauthorized_num_retry = 0
self._unauthorized_max_retry = 1
if auto_refresh_token and self.token_refresh_date is not None:
self._start_token_refresh_thread()

def exit_handler():
self._stop_event.set()
if self._token_refresh_thread is not None:
self._token_refresh_thread.join(timeout=5)
if self._dt_client is not None:
log.info("Stopping the data transfer client gracefully.")
self._dt_client.stop()
Expand Down Expand Up @@ -344,6 +391,103 @@ def auth_api_url(self) -> str:
log.error("auth_api not valid for non-keycloak implementation")
return None

def _start_token_refresh_thread(self):
"""Start a background thread to refresh the access token."""
if self._token_refresh_thread is not None and self._token_refresh_thread.is_alive():
return

self._token_refresh_thread = threading.Thread(
target=self._periodically_refresh_token,
name="periodic_token_refresh",
)
self._token_refresh_thread.daemon = True
self._token_refresh_thread.start()

def _update_token_expiry(self, tokens):
"""Update expiry-related fields from a token response."""
expires_in = []
access_expires_in = tokens.get("expires_in", None)
if access_expires_in is not None:
log.debug(f"Access token expires in {timedelta(seconds=int(access_expires_in))}")
expires_in.append(access_expires_in)
refresh_expires_in = tokens.get("refresh_expires_in", None)
if refresh_expires_in is not None:
info = (
"offline"
if refresh_expires_in == 0 and "offline_access" in self.scope
else f"expires in {timedelta(seconds=int(refresh_expires_in))}"
)
log.debug(f"Refresh token {info}")
if refresh_expires_in > 0:
expires_in.append(refresh_expires_in)
self.token_expires_in = min(expires_in) if expires_in else None
if self.token_expires_in is not None:
log.debug(f"Setting token expiry to {timedelta(seconds=int(self.token_expires_in))}")
self.token_acquired_date = (
datetime.now(timezone.utc) if self.token_expires_in is not None else None
)

self._refresh_attempt = 0
if self.token_expires_in is not None:
offset = max(1, int(self.token_expires_in * self.token_refresh_factor))
self.token_refresh_date = self.token_acquired_date + timedelta(seconds=offset)
log.debug(
"Refresh token set, auto refresh in "
f"{timedelta(seconds=offset)} ({self.token_refresh_date})"
)
else:
self.token_refresh_date = None

def _periodically_refresh_token(self):
"""Periodically check if the token needs to be refreshed and refresh it."""
while not self._stop_event.is_set():
if self.token_refresh_date is None:
if self._stop_event.wait(self.loop_interval):
break
continue

now = datetime.now(timezone.utc)
if now > self.token_refresh_date:
log.debug("Attempting preemptive authentication token refresh")
try:
self.refresh_access_token()
except Exception as ex:
self._reschedule_after_failed_refresh(ex)
continue

diff = self.token_refresh_date - now
sleep_time = max(0.1, min(self.loop_interval, diff.total_seconds()))
if self._stop_event.wait(sleep_time):
break
log.debug("Token refresh thread stopped")

def _reschedule_after_failed_refresh(self, ex):
"""Schedule the next refresh attempt after a failure, if any retries remain."""
self._refresh_attempt += 1
if (
self._refresh_attempt > len(self.token_refresh_retry_factors)
or self.token_acquired_date is None
or self.token_expires_in is None
):
log.error(
"Preemptive token refresh failed and no retries remain: %s. "
"Falling back to on-demand refresh on the next 401 response.",
ex,
)
self.token_refresh_date = None
return

factor = self.token_refresh_retry_factors[self._refresh_attempt - 1]
offset = max(1, int(self.token_expires_in * factor))
self.token_refresh_date = self.token_acquired_date + timedelta(seconds=offset)
log.warning(
"Preemptive token refresh failed (%s); next attempt scheduled at %.0f%% "
"of token lifetime (%s).",
ex,
factor * 100,
self.token_refresh_date,
)

def _auto_refresh_token(self, response, *args, **kwargs):
"""Provide a callback for refreshing an expired token.

Expand Down Expand Up @@ -396,6 +540,7 @@ def refresh_access_token(self):
self.access_token = tokens["access_token"]
self.refresh_token = tokens.get("refresh_token", None)
self.session.headers.update({"Authorization": f"Bearer {tokens['access_token']}"})
self._update_token_expiry(tokens)

@property
def data_transfer_client(self) -> DataTransferClient:
Expand Down
79 changes: 79 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
# SOFTWARE.get_projects

import logging
import threading
import time
from datetime import datetime, timedelta, timezone

import pytest
import requests
Expand Down Expand Up @@ -125,3 +127,80 @@ def test_dt_client(url, username, password):

assert client.data_transfer_client == client._dt_client
assert client.data_transfer_api == client._dt_api


def test_update_token_expiry_sets_refresh_date(url, username, password):
"""After authentication, expiry-related fields must be populated."""
client = Client(url, username, password)

assert client.token_expires_in is not None
assert client.token_acquired_date is not None
assert client.token_refresh_date is not None
# refresh date should be in the future and within token lifetime
assert client.token_refresh_date > client.token_acquired_date
diff = (client.token_refresh_date - client.token_acquired_date).total_seconds()
assert 0 < diff <= client.token_expires_in


def test_update_token_expiry_updates_after_refresh(url, username, password):
"""Calling refresh_access_token must move token_refresh_date forward."""
client = Client(url, username, password)
first_refresh_date = client.token_refresh_date

time.sleep(0.5)
client.refresh_access_token()

assert client.token_refresh_date > first_refresh_date


def test_reschedule_after_failed_refresh(url, username, password):
"""Failed refreshes must escalate through retry factors, then give up."""
client = Client(url, username, password)

# Stop the background thread so it doesn't race with our manipulations.
client._stop_event.set()
client._token_refresh_thread.join(timeout=5)

acquired = client.token_acquired_date
expires_in = client.token_expires_in
retry_factors = client.token_refresh_retry_factors
assert len(retry_factors) > 0

err = RuntimeError("simulated refresh failure")

# Each failure should reschedule at the next retry factor.
for i, factor in enumerate(retry_factors, start=1):
client._reschedule_after_failed_refresh(err)
assert client._refresh_attempt == i
expected_offset = max(1, int(expires_in * factor))
expected_date = acquired + timedelta(seconds=expected_offset)
assert client.token_refresh_date == expected_date

# One more failure exhausts the retries and disables preemptive refresh.
client._reschedule_after_failed_refresh(err)
assert client.token_refresh_date is None


def test_periodically_refresh_token_refreshes_preemptively(url, username, password):
"""The background thread must refresh the access token before it expires."""
client = Client(url, username, password)
initial_access_token = client.access_token

# The background thread is already sleeping with the default loop_interval
# against a fresh token_refresh_date. Stop it, retune the schedule so a
# refresh is due immediately, then restart so the new values take effect.
client._stop_event.set()
client._token_refresh_thread.join(timeout=5)
client._stop_event = threading.Event()
client._token_refresh_thread = None
client.loop_interval = 0.1
client.token_refresh_date = datetime.now(timezone.utc) - timedelta(seconds=1)
client._start_token_refresh_thread()

# Wait for the background thread to perform the refresh
deadline = time.time() + 5
while time.time() < deadline and client.access_token == initial_access_token:
time.sleep(0.1)

assert client.access_token != initial_access_token
assert client.token_refresh_date > datetime.now(timezone.utc)
Loading