diff --git a/pom.xml b/pom.xml index 4e5977f..1fda826 100644 --- a/pom.xml +++ b/pom.xml @@ -166,6 +166,27 @@ 4.12.0 test + + + io.jsonwebtoken + jjwt-api + 0.12.6 + test + + + + io.jsonwebtoken + jjwt-impl + 0.12.6 + test + + + + io.jsonwebtoken + jjwt-jackson + 0.12.6 + test + diff --git a/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java b/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java index 60f372c..4e231d8 100644 --- a/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java +++ b/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java @@ -1,13 +1,11 @@ package de.privateaim.node_message_broker.common; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.ObjectMapper; import de.privateaim.node_message_broker.ConfigurationUtil; import de.privateaim.node_message_broker.common.hub.HttpHubClient; -import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig; import de.privateaim.node_message_broker.common.hub.HubClient; -import de.privateaim.node_message_broker.common.hub.auth.HttpHubAuthClient; -import de.privateaim.node_message_broker.common.hub.auth.HttpHubAuthClientConfig; -import de.privateaim.node_message_broker.common.hub.auth.HubAuthClient; -import de.privateaim.node_message_broker.common.hub.auth.RenewAuthTokenFilter; +import de.privateaim.node_message_broker.common.hub.auth.HubOIDCAuthenticator; import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; @@ -24,6 +22,9 @@ @Configuration public class CommonSpringConfig { + private static final int EXCHANGE__MAX_RETRIES = 5; + private static final int EXCHANGE__MAX_RETRY_DELAY_MS = 1000; + @Value("${app.hub.baseUrl}") private String hubCoreBaseUrl; @@ -48,10 +49,16 @@ public String hubAuthRobotId() { return hubAuthRobotId; } + @Qualifier("HUB_EXCHANGE_RETRY_CONFIG") + @Bean + HttpRetryConfig exchangeRetryConfig() { + return new HttpRetryConfig(EXCHANGE__MAX_RETRIES, EXCHANGE__MAX_RETRY_DELAY_MS); + } + @Qualifier("HUB_CORE_WEB_CLIENT") @Bean public WebClient alwaysReAuthenticatedWebClient( - @Qualifier("HUB_AUTH_RENEW_TOKEN") ExchangeFilterFunction renewTokenFilter, + @Qualifier("HUB_AUTHENTICATION_MIDDLEWARE") ExchangeFilterFunction authenticationMiddleware, @Qualifier("BASE_SSL_HTTP_CLIENT_CONNECTOR") ReactorClientHttpConnector baseSslHttpClientConnector) { // We can't use Spring's default security mechanisms out-of-the-box here since HUB uses a non-standard grant // type which is not supported. There's a way by using a custom grant type accompanied by a client manager. @@ -65,19 +72,17 @@ public WebClient alwaysReAuthenticatedWebClient( return WebClient.builder() .uriBuilderFactory(factory) .defaultHeaders(httpHeaders -> httpHeaders.setAccept(List.of(MediaType.APPLICATION_JSON))) - .filter(renewTokenFilter) + .filter(authenticationMiddleware) .clientConnector(baseSslHttpClientConnector) .build(); } @Bean - public HubClient hubClient(@Qualifier("HUB_CORE_WEB_CLIENT") WebClient alwaysReAuthenticatedWebClient) { - var clientConfig = new HttpHubClientConfig.Builder() - .withMaxRetries(5) - .withRetryDelayMs(1000) - .build(); - - return new HttpHubClient(alwaysReAuthenticatedWebClient, clientConfig); + public HubClient hubClient( + @Qualifier("HUB_CORE_WEB_CLIENT") WebClient webClient, + @Qualifier("HUB_EXCHANGE_RETRY_CONFIG") HttpRetryConfig retryConfig + ) { + return new HttpHubClient(webClient, retryConfig); } @Qualifier("HUB_AUTH_WEB_CLIENT") @@ -91,21 +96,37 @@ WebClient hubAuthWebClient( .build(); } - @Qualifier("HUB_AUTH_CLIENT") + @Qualifier("HUB_JSON_MAPPER") + @Bean + ObjectMapper simpleJsonMapper() { + return new ObjectMapper() + .findAndRegisterModules() + .disable(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES) + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES); + } + + @Qualifier("HUB_AUTHENTICATOR") @Bean - public HubAuthClient hubAuthClient(@Qualifier("HUB_AUTH_WEB_CLIENT") WebClient webClient) { - var clientConfig = new HttpHubAuthClientConfig.Builder() - .withMaxRetries(5) - .withRetryDelayMs(1000) + OIDCAuthenticator hubAuthenticator( + @Qualifier("HUB_AUTH_WEB_CLIENT") WebClient webClient, + @Qualifier("HUB_EXCHANGE_RETRY_CONFIG") HttpRetryConfig retryConfig, + @Qualifier("HUB_AUTH_ROBOT_ID") String hubAuthRobotId, + @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret, + @Qualifier("HUB_JSON_MAPPER") ObjectMapper jsonMapper + ) { + return HubOIDCAuthenticator.builder() + .usingWebClient(webClient) + .withRetryConfig(retryConfig) + .withAuthCredentials(hubAuthRobotId, hubAuthRobotSecret) + .withJsonDecoder(jsonMapper) .build(); - return new HttpHubAuthClient(webClient, clientConfig); } - @Qualifier("HUB_AUTH_RENEW_TOKEN") + @Qualifier("HUB_AUTHENTICATION_MIDDLEWARE") @Bean - ExchangeFilterFunction renewAuthTokenFilter( - @Qualifier("HUB_AUTH_CLIENT") HubAuthClient hubAuthClient, - @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret) { - return new RenewAuthTokenFilter(hubAuthClient, hubAuthRobotId, hubAuthRobotSecret); + ExchangeFilterFunction hubAuthenticationMiddleware( + @Qualifier("HUB_AUTHENTICATOR") OIDCAuthenticator authenticator + ) { + return new OIDCAuthenticatorMiddleware(authenticator); } } diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java b/src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java similarity index 73% rename from src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java rename to src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java index 1ba97a8..e70a523 100644 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java +++ b/src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java @@ -1,12 +1,12 @@ -package de.privateaim.node_message_broker.common.hub; +package de.privateaim.node_message_broker.common; /** - * Additional behavioural configuration for the {@link HttpHubClient}. + * Configuration options for retrying failed HTTP requests. * * @param maxRetries number of maximum retries carried out by the client in case of a retryable error * @param retryDelayMs time between retries in ms */ -public record HttpHubClientConfig(int maxRetries, int retryDelayMs) { +public record HttpRetryConfig(int maxRetries, int retryDelayMs) { public static final class Builder { private int maxRetries = 5; private int retryDelayMs = 2000; @@ -21,7 +21,7 @@ public Builder withRetryDelayMs(int retryDelayMs) { return this; } - public HttpHubClientConfig build() { + public HttpRetryConfig build() { if (maxRetries < 0) { throw new IllegalArgumentException("maxRetries must be greater than 0"); } @@ -30,7 +30,7 @@ public HttpHubClientConfig build() { throw new IllegalArgumentException("retryDelayMs must be greater than 0"); } - return new HttpHubClientConfig(maxRetries, retryDelayMs); + return new HttpRetryConfig(maxRetries, retryDelayMs); } } } diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java new file mode 100644 index 0000000..d7f600b --- /dev/null +++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java @@ -0,0 +1,24 @@ +package de.privateaim.node_message_broker.common; + +import org.springframework.security.oauth2.core.OAuth2Token; +import reactor.core.publisher.Mono; + +/** + * Describes an OIDC compliant authenticator. + */ +public interface OIDCAuthenticator { + /** + * Authenticates against an external system. + * + * @return A pair of access and refresh token. + */ + Mono authenticate(); + + /** + * Refreshes the authentication against an external system. + * + * @param refreshToken The refresh token to be used. + * @return A pair of access and refresh token. + */ + Mono refresh(OAuth2Token refreshToken); +} diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java new file mode 100644 index 0000000..88ea2f1 --- /dev/null +++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java @@ -0,0 +1,91 @@ +package de.privateaim.node_message_broker.common; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatus; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFilterFunction; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import reactor.core.publisher.Mono; + +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import static java.util.Objects.requireNonNull; + +/** + * A middleware for authenticating against several different services of a single external provider. + */ +@Slf4j +public final class OIDCAuthenticatorMiddleware implements ExchangeFilterFunction { + + private final Map tokenPairByHost; + private final OIDCAuthenticator oidcAuthenticator; + + public OIDCAuthenticatorMiddleware(OIDCAuthenticator oidcAuthenticator) { + this.tokenPairByHost = new ConcurrentHashMap<>(); + this.oidcAuthenticator = requireNonNull(oidcAuthenticator, "OIDC authenticator must not be null"); + } + + @Override + public @NonNull Mono filter(@NonNull ClientRequest request, ExchangeFunction next) { + var host = request.url().getHost(); + if (host == null) { + return Mono.just(ClientResponse.create(HttpStatus.BAD_REQUEST).build()); + } + + var authToken = computeAccessTokenForHost(host); + var authenticatedRequest = ClientRequest.from(request) + .headers(headers -> headers.setBearerAuth(authToken.getTokenValue())) + .build(); + + return next.exchange(authenticatedRequest).flatMap(response -> { + // Handling for unauthorized events in case of time overlaps regarding token expiration. + // Can happen if time skew is not properly handled by the authentication server. + if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) { + return response.releaseBody() + .then(Mono.just(computeAccessTokenForHost(host))) + .flatMap(token -> { + var newRequest = ClientRequest.from(request) + .headers(headers -> headers.setBearerAuth(token.getTokenValue())) + .build(); + log.warn("retrying request to '{}' with new bearer token after receiving status code 401 " + + "(unauthorized)", request.url()); + return next.exchange(newRequest); + }); + } else { + return Mono.just(response); + } + }); + } + + private OAuth2AccessToken computeAccessTokenForHost(@NonNull String host) { + // TODO: revise - find another way to circumvent using block() + // The following is an atomic operation! + return tokenPairByHost.compute(host, (unused, tokenPair) -> { + if (tokenPair == null) { + log.info("acquiring access token for host '{}' as there is none yet", host); + return oidcAuthenticator.authenticate().block(); + } + + if (tokenPair.accessToken().getExpiresAt().isBefore(Instant.now())) { + return tokenPair.refreshToken() + .map(refreshToken -> { + if (refreshToken.getExpiresAt().isBefore(Instant.now())) { + log.warn("refresh token expired - acquiring new pair of access token and refresh token for " + + "host '{}'", host); + return oidcAuthenticator.authenticate().block(); + } else { + log.info("refreshing access token for host '{}'", host); + return oidcAuthenticator.refresh(refreshToken).block(); + } + }) + .orElseGet(() -> oidcAuthenticator.authenticate().block()); + } + return tokenPair; + }).accessToken(); + } +} diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java new file mode 100644 index 0000000..0876394 --- /dev/null +++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java @@ -0,0 +1,18 @@ +package de.privateaim.node_message_broker.common; + +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; + +import java.util.Optional; + +/** + * An OIDC compliant pair of access token & refresh token. + * + * @param accessToken JWT acting as an access token. + * @param refreshToken JWT acting as a refresh token for acquiring a new access token. + */ +public record OIDCTokenPair( + OAuth2AccessToken accessToken, + Optional refreshToken +) { +} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java index 06002b7..26f7766 100644 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java +++ b/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java @@ -1,5 +1,6 @@ package de.privateaim.node_message_broker.common.hub; +import de.privateaim.node_message_broker.common.HttpRetryConfig; import de.privateaim.node_message_broker.common.hub.api.AnalysisNode; import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer; import de.privateaim.node_message_broker.common.hub.api.Node; @@ -30,11 +31,11 @@ public final class HttpHubClient implements HubClient { private final WebClient authenticatedWebClient; - private final HttpHubClientConfig config; + private final HttpRetryConfig retryConfig; - public HttpHubClient(WebClient authenticatedWebClient, HttpHubClientConfig config) { + public HttpHubClient(WebClient authenticatedWebClient, HttpRetryConfig retryConfig) { this.authenticatedWebClient = requireNonNull(authenticatedWebClient, "authenticated web client must not be null"); - this.config = requireNonNull(config, "config must not be null"); + this.retryConfig = requireNonNull(retryConfig, "retry config must not be null"); } // TODO: this might use a cache to cut corners and improve performance by avoiding unnecessary round-trips @@ -59,12 +60,12 @@ public Mono> fetchAnalysisNodes(String analysisId) { .bodyToMono(new ParameterizedTypeReference>>() { }) .map(resp -> resp.data) - .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs())) + .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs())) .jitter(0.75) .filter(err -> err instanceof HubCoreServerException) .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) -> new HubAnalysisNodesNotObtainable("exhausted maximum number of retries of '%d'" - .formatted(config.maxRetries()))))); + .formatted(retryConfig.maxRetries()))))); } // TODO: add cache here! - see spring annotations @@ -110,11 +111,11 @@ public Mono fetchPublicKey(String robotId) { "robot id `%s`".formatted(robotId), e)); } }) - .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs())) + .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs())) .jitter(0.75) .filter(err -> err instanceof HubCoreServerException) .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) -> new HubNodePublicKeyNotObtainable("exhausted maximum number of retries of '%d'" - .formatted(config.maxRetries()))))); + .formatted(retryConfig.maxRetries()))))); } } diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java deleted file mode 100644 index e5d0ebd..0000000 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java +++ /dev/null @@ -1,59 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -import de.privateaim.node_message_broker.common.hub.auth.api.HubAuthTokenResponse; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.HttpStatusCode; -import org.springframework.http.MediaType; -import org.springframework.web.reactive.function.BodyInserters; -import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Mono; -import reactor.util.retry.Retry; - -import java.time.Duration; - -import static java.util.Objects.requireNonNull; - -/** - * A client for communicating with the hub's auth service via HTTP/HTTPS. - */ -@Slf4j -public final class HttpHubAuthClient implements HubAuthClient { - - private final WebClient webClient; - private final HttpHubAuthClientConfig config; - - private static final String TOKEN_PATH = "/token"; - private static final String GRANT_TYPE = "robot_credentials"; - - public HttpHubAuthClient(WebClient webClient, HttpHubAuthClientConfig config) { - this.webClient = requireNonNull(webClient, "web client must not be null"); - this.config = requireNonNull(config, "config must not be null"); - } - - @Override - public Mono requestAccessToken(String clientId, String clientSecret) { - return webClient.post() - .uri(TOKEN_PATH) - .contentType(MediaType.APPLICATION_FORM_URLENCODED) - .accept(MediaType.APPLICATION_JSON) - .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE) - .with("id", clientId) - .with("secret", clientSecret)) - .retrieve() - .onStatus(HttpStatusCode::is5xxServerError, - response -> { - var err = new HubAuthException("could not fetch hub access token"); - - log.warn("retrying token request after failed attempt", err); - return Mono.error(err); - }) - .bodyToMono(HubAuthTokenResponse.class) - .map(hubAuthTokenResponse -> hubAuthTokenResponse.accessToken) - .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs())) - .jitter(0.75) - .filter(err -> err instanceof HubAuthException) - .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) -> - new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'" - .formatted(config.maxRetries()))))); - } -} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java deleted file mode 100644 index 5820594..0000000 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java +++ /dev/null @@ -1,37 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -/** - * Additional behavioural configuration for the {@link HttpHubAuthClient}. - * - * @param maxRetries number of maximum retries carried out by the client in case of a retryable error - * @param retryDelayMs time between retries in ms - */ -public record HttpHubAuthClientConfig(int maxRetries, int retryDelayMs) { - - public static final class Builder { - private int maxRetries = 5; - private int retryDelayMs = 2000; - - public Builder withMaxRetries(int maxRetries) { - this.maxRetries = maxRetries; - return this; - } - - public Builder withRetryDelayMs(int retryDelayMs) { - this.retryDelayMs = retryDelayMs; - return this; - } - - public HttpHubAuthClientConfig build() { - if (maxRetries < 0) { - throw new IllegalArgumentException("maxRetries must be greater than 0"); - } - - if (retryDelayMs < 0) { - throw new IllegalArgumentException("retryDelayMs must be greater than 0"); - } - - return new HttpHubAuthClientConfig(maxRetries, retryDelayMs); - } - } -} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java deleted file mode 100644 index b117aeb..0000000 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java +++ /dev/null @@ -1,18 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -import reactor.core.publisher.Mono; - -/** - * Describes a client able to carry out authentication related operations. - */ -public interface HubAuthClient { - - /** - * Asynchronously requests an access token from the hub using the "robot_credentials" grant type. - * - * @param clientId the client's id - * @param clientSecret the client's secret - * @return The access token. - */ - Mono requestAccessToken(String clientId, String clientSecret); -} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java new file mode 100644 index 0000000..88cf846 --- /dev/null +++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java @@ -0,0 +1,10 @@ +package de.privateaim.node_message_broker.common.hub.auth; + +/** + * Signals that the response from the Hub's auth service is malformed. + */ +public class HubAuthMalformedResponseException extends Exception { + public HubAuthMalformedResponseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java new file mode 100644 index 0000000..5b97603 --- /dev/null +++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java @@ -0,0 +1,256 @@ +package de.privateaim.node_message_broker.common.hub.auth; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import de.privateaim.node_message_broker.common.HttpRetryConfig; +import de.privateaim.node_message_broker.common.OIDCAuthenticator; +import de.privateaim.node_message_broker.common.OIDCTokenPair; +import de.privateaim.node_message_broker.common.hub.auth.api.HubAuthTokenResponse; +import lombok.Getter; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.springframework.http.HttpStatusCode; +import org.springframework.http.MediaType; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.web.reactive.function.BodyInserters; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +import java.time.Duration; +import java.time.Instant; +import java.util.Base64; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +/** + * An OIDC compliant authenticator for authenticating against the Hub. + */ +@Slf4j +public final class HubOIDCAuthenticator implements OIDCAuthenticator { + + private final WebClient webClient; + private final HttpRetryConfig retryConfig; + private final ObjectMapper jsonDecoder; + private final String clientId; + private final String clientSecret; + + private static final String TOKEN_PATH = "/token"; + private static final String GRANT_TYPE_AUTHENTICATE = "robot_credentials"; + private static final String GRANT_TYPE_REFRESH_TOKEN = "refresh_token"; + + private static final String JWT_CLAIM__ISSUED_AT = "iat"; + private static final String JWT_CLAIM__EXPIRES_AT = "exp"; + + private HubOIDCAuthenticator(Builder builder) { + this.webClient = requireNonNull(builder.getWebClient(), "web client must not be null"); + this.retryConfig = requireNonNull(builder.getRetryConfig(), "retry configuration must not be null"); + this.jsonDecoder = requireNonNull(builder.getJsonDecoder(), "json decoder must not be null"); + this.clientId = requireNonNull(builder.getClientId(), "client ID must not be null"); + this.clientSecret = requireNonNull(builder.getClientSecret(), "client secret must not be null"); + } + + @Override + public Mono authenticate() { + return webClient.post() + .uri(TOKEN_PATH) + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .accept(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE_AUTHENTICATE) + .with("id", clientId) + .with("secret", clientSecret)) + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, + response -> { + var err = new HubAuthException("could not fetch hub access token"); + + log.warn("retrying token request after failed attempt", err); + return Mono.error(err); + }) + .bodyToMono(HubAuthTokenResponse.class) + .flatMap(hubAuthTokenResponse -> { + try { + return Mono.just(new OIDCTokenPair( + parseAccessToken(hubAuthTokenResponse.accessToken), + Optional.ofNullable(parseRefreshToken(hubAuthTokenResponse.refreshToken)))); + } catch (IllegalArgumentException e) { + return Mono.error( + new HubAuthMalformedResponseException("received malformed response from HUB when " + + "trying to fetch access token", e)); + } + }) + .doOnError(err -> log.error(err.getMessage(), err)) + .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs())) + .jitter(0.75) + .filter(err -> err instanceof HubAuthException) + .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) -> + new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'" + .formatted(retryConfig.maxRetries()))))); + } + + @Override + public Mono refresh(OAuth2Token refreshToken) { + return webClient.post() + .uri(TOKEN_PATH) + .contentType(MediaType.APPLICATION_FORM_URLENCODED) + .accept(MediaType.APPLICATION_JSON) + .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE_REFRESH_TOKEN) + .with("refresh_token", refreshToken.getTokenValue())) + .retrieve() + .onStatus(HttpStatusCode::is5xxServerError, + response -> { + var err = new HubAuthException("could not fetch new access token using refresh token"); + + log.warn("retrying token refresh request after failed attempt", err); + return Mono.error(err); + }) + .bodyToMono(HubAuthTokenResponse.class) + .flatMap(hubAuthTokenResponse -> { + try { + return Mono.just(new OIDCTokenPair( + parseAccessToken(hubAuthTokenResponse.accessToken), + Optional.ofNullable(parseRefreshToken(hubAuthTokenResponse.refreshToken)))); + } catch (IllegalArgumentException e) { + return Mono.error( + new HubAuthMalformedResponseException("received malformed response from HUB when " + + "trying to refresh access token", e)); + } + }) + .doOnError(err -> log.error(err.getMessage(), err)) + .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs())) + .jitter(0.75) + .filter(err -> err instanceof HubAuthException) + .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) -> + new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'" + .formatted(retryConfig.maxRetries()))))); + } + + private OAuth2AccessToken parseAccessToken(@NonNull String tokenValue) { + try { + var jwt = decodeJwt(tokenValue); + if (jwt.getIssuedAt() == null) { + throw new IllegalArgumentException("JWT is missing issuing claim"); + } + if (jwt.getExpiresAt() == null) { + throw new IllegalArgumentException("JWT is missing expiration claim"); + } + + return new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, + jwt.getTokenValue(), + jwt.getIssuedAt(), + jwt.getExpiresAt()); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("failed to decode access token from Hub"); + } + } + + private OAuth2RefreshToken parseRefreshToken(String tokenValue) { + if (tokenValue == null) { + return null; + } + + try { + var jwt = decodeJwt(tokenValue); + if (jwt.getIssuedAt() == null) { + throw new IllegalArgumentException("JWT is missing issuing claim"); + } + if (jwt.getExpiresAt() == null) { + throw new IllegalArgumentException("JWT is missing expiration claim"); + } + + return new OAuth2RefreshToken( + jwt.getTokenValue(), + jwt.getIssuedAt(), + jwt.getExpiresAt()); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException("failed to decode refresh token from Hub"); + } + } + + private OAuth2Token decodeJwt(@NonNull String tokenValue) throws JsonProcessingException { + var chunks = tokenValue.split("\\."); + + if (chunks.length < 2) { + throw new IllegalArgumentException("could not decode JWT due to missing payload section"); + } + var payload = new String(Base64.getUrlDecoder().decode(chunks[1])); + + var payloadJson = jsonDecoder.readTree(payload); + + var issuedAt = Optional.ofNullable(payloadJson.get(JWT_CLAIM__ISSUED_AT)) + .map(JsonNode::asLong) + .map(Instant::ofEpochSecond) + .orElse(null); + var expiresAt = Optional.ofNullable(payloadJson.get(JWT_CLAIM__EXPIRES_AT)) + .map(JsonNode::asLong) + .map(Instant::ofEpochSecond) + .orElse(null); + + return new HubJWT(tokenValue, issuedAt, expiresAt); + } + + private record HubJWT(String tokenValue, Instant issuedAt, Instant expiresAt) implements OAuth2Token { + @Override + public String getTokenValue() { + return tokenValue; + } + + @Override + public Instant getIssuedAt() { + return issuedAt; + } + + @Override + public Instant getExpiresAt() { + return expiresAt; + } + } + + + public static Builder builder() { + return new Builder(); + } + + @Getter + public static class Builder { + private WebClient webClient; + private HttpRetryConfig retryConfig; + private ObjectMapper jsonDecoder; + private String clientId; + private String clientSecret; + + Builder() { + } + + public Builder usingWebClient(WebClient webClient) { + this.webClient = webClient; + return this; + } + + public Builder withRetryConfig(HttpRetryConfig retryConfig) { + this.retryConfig = retryConfig; + return this; + } + + public Builder withJsonDecoder(ObjectMapper jsonDecoder) { + this.jsonDecoder = jsonDecoder; + return this; + } + + public Builder withAuthCredentials(String clientId, String clientSecret) { + this.clientId = clientId; + this.clientSecret = clientSecret; + return this; + } + + public OIDCAuthenticator build() { + // TODO: add validation step + return new HubOIDCAuthenticator(this); + } + } +} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java deleted file mode 100644 index 7749960..0000000 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java +++ /dev/null @@ -1,63 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.springframework.http.HttpStatus; -import org.springframework.web.reactive.function.client.ClientRequest; -import org.springframework.web.reactive.function.client.ClientResponse; -import org.springframework.web.reactive.function.client.ExchangeFilterFunction; -import org.springframework.web.reactive.function.client.ExchangeFunction; -import reactor.core.publisher.Mono; - -import static java.util.Objects.requireNonNull; - -/** - * Filter for a {@link org.springframework.web.reactive.function.client.WebClient} which handles re-authentication of - * the request if the request target responds with a 401 (Unauthorized) status code. - */ -@Slf4j -public class RenewAuthTokenFilter implements ExchangeFilterFunction { - - private final HubAuthClient authClient; - - private final String robotId; - - private final String robotSecret; - - /** - * Creates a new instance of the filter with information required for re-authenticating any in-flight requests. - * - * @param authClient client able to communicate to the Hub's auth service - * @param robotId client's robot id used for authentication - * @param robotSecret client's robot secret used for authentication - */ - public RenewAuthTokenFilter(HubAuthClient authClient, String robotId, String robotSecret) { - this.authClient = requireNonNull(authClient, "auth client must not be null"); - this.robotId = robotId; - this.robotSecret = robotSecret; - } - - @Override - public @NonNull Mono filter(@NonNull ClientRequest request, ExchangeFunction next) { - return next.exchange(request).flatMap(response -> { - if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) { - return response.releaseBody() - .then(acquireNewToken(robotId, robotSecret)) - .flatMap(token -> { - var newRequest = ClientRequest.from(request) - .headers(headers -> headers.setBearerAuth(token)) - .build(); - log.warn("retrying request to {} after receiving status code 401 (unauthorized)", - request.url()); - return next.exchange(newRequest); - }); - } else { - return Mono.just(response); - } - }); - } - - private Mono acquireNewToken(String robotId, String robotSecret) { - return authClient.requestAccessToken(robotId, robotSecret); - } -} diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java index 6cbb41a..7c5f0d6 100644 --- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java +++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java @@ -15,7 +15,10 @@ @JsonIgnoreProperties(ignoreUnknown = true) public final class HubAuthTokenResponse { - @JsonProperty("access_token") + @JsonProperty(value = "access_token", required = true) public String accessToken; + @JsonProperty(value = "refresh_token") + public String refreshToken; + } diff --git a/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java b/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java index fd61ae7..9e7dc4a 100644 --- a/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java +++ b/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java @@ -2,8 +2,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import de.privateaim.node_message_broker.ConfigurationUtil; +import de.privateaim.node_message_broker.common.OIDCAuthenticator; import de.privateaim.node_message_broker.common.hub.HubClient; -import de.privateaim.node_message_broker.common.hub.auth.HubAuthClient; import de.privateaim.node_message_broker.message.crypto.HubMessageCryptoService; import de.privateaim.node_message_broker.message.crypto.MessageCryptoService; import de.privateaim.node_message_broker.message.emit.EmitMessage; @@ -76,10 +76,8 @@ OkHttpClient socketBaseClient(@Qualifier("COMMON_JAVA_SSL_CONTEXT") SSLContext s @Qualifier("HUB_MESSENGER_UNDERLYING_SOCKET") @Bean(destroyMethod = "disconnect") public Socket underlyingMessengerSocket( - @Qualifier("HUB_AUTH_CLIENT") HubAuthClient hubAuthClient, + @Qualifier("HUB_AUTHENTICATOR") OIDCAuthenticator hubAuthenticator, @Qualifier("HUB_MESSAGE_RECEIVER") MessageReceiver messageReceiver, - @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret, - @Qualifier("HUB_AUTH_ROBOT_ID") String hubAuthRobotId, @Qualifier("HUB_MESSENGER_UNDERLYING_SOCKET_SECURE_CLIENT") OkHttpClient secureBaseClient) { IO.Options options = IO.Options.builder() .setPath(null) @@ -96,8 +94,12 @@ public Socket underlyingMessengerSocket( log.error("cannot connect to hub messenger at `{}`", hubMessengerBaseUrl); // we block here since this is a crucial component - options.auth.put("token", hubAuthClient.requestAccessToken(hubAuthRobotId, hubAuthRobotSecret) - .block()); + var oidcTokenPair = hubAuthenticator.authenticate().block(); + if (oidcTokenPair == null) { + throw new RuntimeException("authentication failed - cannot connect to hub messenger at `%s`" + .formatted(hubMessengerBaseUrl)); + } + options.auth.put("token", oidcTokenPair.accessToken().getTokenValue()); log.info("reconnecting to hub messenger at `{}` with new authentication token", hubMessengerBaseUrl); socket.connect(); @@ -110,8 +112,13 @@ public Socket underlyingMessengerSocket( objects -> { log.info("trying to reconnect to hub messenger via socket at `{}", hubMessengerBaseUrl); // we block here since this is a crucial component - options.auth.put("token", hubAuthClient.requestAccessToken(hubAuthRobotId, hubAuthRobotSecret) - .block()); + var oidcTokenPair = hubAuthenticator.authenticate().block(); + if (oidcTokenPair == null) { + throw new RuntimeException("authentication failed - cannot connect to hub messenger at `%s`" + .formatted(hubMessengerBaseUrl)); + } + + options.auth.put("token", oidcTokenPair.accessToken().getTokenValue()); }); socket.on(SOCKET_RECEIVE_HUB_MESSAGE_IDENTIFIER, objects -> { @@ -241,12 +248,6 @@ MessageSubscriptionService messageSubscriptionService( return new MessageSubscriptionServiceImpl(messageSubscriptionRepository); } - @Qualifier("HUB_MESSAGE_RECEIVE_JSON_MAPPER") - @Bean - ObjectMapper hubMessageJsonMapper() { - return new ObjectMapper(); - } - @Qualifier("HUB_MESSAGE_RECEIVE_MIDDLEWARE_DECRYPT") @Bean Function> hubMessageReceiveDecryptionMiddleware( @@ -303,7 +304,7 @@ MessageConsumer hubMessageConsumer( @Qualifier("HUB_MESSAGE_RECEIVER") @Bean MessageReceiver hubMessageReceiver( - @Qualifier("HUB_MESSAGE_RECEIVE_JSON_MAPPER") ObjectMapper jsonMapper, + @Qualifier("HUB_JSON_MAPPER") ObjectMapper jsonMapper, @Qualifier("HUB_MESSAGE_RECEIVE_MIDDLEWARES") List>> middlewares, @Qualifier("HUB_MESSAGE_RECEIVE_CONSUMER") MessageConsumer messageConsumer ) { diff --git a/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java b/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java new file mode 100644 index 0000000..b205792 --- /dev/null +++ b/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java @@ -0,0 +1,317 @@ +package de.privateaim.node_message_broker.common; + +import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; +import io.jsonwebtoken.Jwts; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.http.HttpMethod; +import org.springframework.http.HttpStatusCode; +import org.springframework.security.oauth2.core.OAuth2AccessToken; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.web.reactive.function.client.ClientRequest; +import org.springframework.web.reactive.function.client.ClientResponse; +import org.springframework.web.reactive.function.client.ExchangeFunction; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import java.net.URI; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Date; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.*; + +@ExtendWith(MockitoExtension.class) +public class OIDCAuthenticatorMiddlewareTest { + + @Mock + private OIDCAuthenticator authenticator; + + @InjectMocks + private OIDCAuthenticatorMiddleware authenticatorMiddleware; + + private OAuth2RefreshToken getSimpleRefreshToken(Instant issuedAt, Instant expiresAt) { + var refreshTokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(issuedAt)) + .expiration(Date.from(expiresAt)) + .compact(); + + return new OAuth2RefreshToken( + refreshTokenValue, + issuedAt, + expiresAt + ); + } + + private OAuth2AccessToken getSimpleAccessToken(Instant issuedAt, Instant expiresAt) { + var accessTokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(issuedAt)) + .expiration(Date.from(expiresAt)) + .compact(); + + return new OAuth2AccessToken( + OAuth2AccessToken.TokenType.BEARER, + accessTokenValue, + issuedAt, + expiresAt + ); + } + + private String getBearerToken(ClientRequest request) { + var authorizationHeaderValue = request.headers().getFirst("Authorization"); + if (authorizationHeaderValue.startsWith("Bearer ")) { + return authorizationHeaderValue.substring("Bearer ".length()); + } + + return null; + } + + @Test + void requestFailsIfNoHostInformationIsPresent() { + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("/no-host/just-a-path")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + var responses = new ArrayList(); + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(responses::addAll) + .verifyComplete(); + + assertEquals(1, responses.size()); + assertEquals(HttpStatus.SC_BAD_REQUEST, responses.getFirst().statusCode().value()); + } + + @Test + void firstRequestForAHostRequiresAuthentication() { + var issuedAt = Instant.now(); + var expiresAt = issuedAt.plus(Duration.ofHours(1)); + + var accessToken = getSimpleAccessToken(issuedAt, expiresAt); + var refreshToken = getSimpleRefreshToken(issuedAt, expiresAt); + + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + doReturn(Mono.just(new OIDCTokenPair(accessToken, Optional.of(refreshToken)))) + .when(authenticator).authenticate(); + doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); + + + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(1)).authenticate(); + } + + @Test + void existingAccessTokenGetsReusedIfNotExpired() { + var issuedAt = Instant.now(); + var expiresAt = issuedAt.plus(Duration.ofHours(1)); + + var accessToken = getSimpleAccessToken(issuedAt, expiresAt); + var refreshToken = getSimpleRefreshToken(issuedAt, expiresAt); + + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + doReturn(Mono.just(new OIDCTokenPair(accessToken, Optional.of(refreshToken)))) + .when(authenticator).authenticate(); + doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); + + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(1)).authenticate(); + } + + @Test + void newAccessTokenIsAcquiredIfExpiredUsingRefreshTokenIfPresent() { + var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)), + Instant.now().minus(Duration.ofHours(1))); + var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2))); + + + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken)))) + .when(authenticator).authenticate(); + doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken)))) + .when(authenticator).refresh(refreshToken); + doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); + + // This in fact runs the first request with an expired token. However, we are not interested in that fact + // here. It's more about filling the cache for the next request to the same host. + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(1)).authenticate(); + verify(authenticator, Mockito.times(1)).refresh(refreshToken); + } + + @Test + void newAccessTokenIsAcquiredEvenIfRefreshTokenIsAbsent() { + var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)), + Instant.now().minus(Duration.ofHours(1))); + var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + + + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.empty())), + Mono.just(new OIDCTokenPair(newAccessToken, Optional.empty()))) + .when(authenticator).authenticate(); + doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); + + // This in fact runs the first request with an expired token. However, we are not interested in that fact + // here. It's more about filling the cache for the next request to the same host. + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(2)).authenticate(); + } + + @Test + void newAccessAndRefreshTokenPairAcquiredIfBothHaveExpired() { + var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(3)), + Instant.now().minus(Duration.ofHours(1))); + var expiredRefreshToken = getSimpleRefreshToken(Instant.now().minus(Duration.ofHours(3)), + Instant.now().minus(Duration.ofHours(2))); + var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2))); + + + var targetResponse = Mockito.mock(ClientResponse.class); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); + + when(authenticator.authenticate()) + .thenReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(expiredRefreshToken)))) + .thenReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken)))); + doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); + + // This in fact runs the first request with an expired token. However, we are not interested in that fact + // here. It's more about filling the cache for the next request to the same host. + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + + verify(authenticator, Mockito.times(2)).authenticate(); + verify(authenticator, Mockito.never()).refresh(Mockito.any(OAuth2RefreshToken.class)); + } + + @Test + void attemptsNewRequestWithNewAccessTokenIfServerRespondsWithUnauthorized() { + var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)), + Instant.now().minus(Duration.ofHours(1))); + var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2))); + + var targetResponse = Mockito.mock(ClientResponse.class); + var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED)) + .body("unauthorized").build(); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = Mockito.mock(ExchangeFunction.class); + + doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken)))) + .when(authenticator).authenticate(); + doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken)))) + .when(authenticator).refresh(refreshToken); + + var requestCaptor = ArgumentCaptor.forClass(ClientRequest.class); + doReturn(Mono.just(unauthorizedTargetResponse), Mono.just(targetResponse)) + .when(targetExchangeFunction) + .exchange(requestCaptor.capture()); + + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(targetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(1)).authenticate(); + verify(authenticator, Mockito.times(1)).refresh(refreshToken); + + var capturedRequests = requestCaptor.getAllValues(); + + assertEquals(2, capturedRequests.size()); + assertEquals(expiredAccessToken.getTokenValue(), getBearerToken(capturedRequests.getFirst())); + assertEquals(newAccessToken.getTokenValue(), getBearerToken(capturedRequests.getLast())); + } + + @Test + void requestEventuallyFailsIfRetryIsStillUnauthorized() { + var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)), + Instant.now().minus(Duration.ofHours(1))); + var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1))); + var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2))); + + var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED)) + .body("unauthorized").build(); + var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build(); + ExchangeFunction targetExchangeFunction = Mockito.mock(ExchangeFunction.class); + + doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken)))) + .when(authenticator).authenticate(); + doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken)))) + .when(authenticator).refresh(refreshToken); + + var requestCaptor = ArgumentCaptor.forClass(ClientRequest.class); + doReturn(Mono.just(unauthorizedTargetResponse)) + .when(targetExchangeFunction) + .exchange(requestCaptor.capture()); + + StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction)) + .expectNext(unauthorizedTargetResponse) + .expectComplete() + .verify(); + verify(authenticator, Mockito.times(1)).authenticate(); + verify(authenticator, Mockito.times(1)).refresh(refreshToken); + + var capturedRequests = requestCaptor.getAllValues(); + + assertEquals(2, capturedRequests.size()); + assertEquals(expiredAccessToken.getTokenValue(), getBearerToken(capturedRequests.getFirst())); + assertEquals(newAccessToken.getTokenValue(), getBearerToken(capturedRequests.getLast())); + } +} diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java index cebb5c8..bd724e2 100644 --- a/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java +++ b/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java @@ -1,6 +1,7 @@ package de.privateaim.node_message_broker.common.hub; import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; +import de.privateaim.node_message_broker.common.HttpRetryConfig; import de.privateaim.node_message_broker.common.hub.api.AnalysisNode; import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer; import de.privateaim.node_message_broker.common.hub.api.Node; @@ -50,10 +51,7 @@ public class HttpHubClientIT { void setUp() { mockWebServer = new MockWebServer(); var webClient = WebClient.create(mockWebServer.url("/").toString()); - httpHubClient = new HttpHubClient(webClient, new HttpHubClientConfig.Builder() - .withMaxRetries(MAX_RETRIES) - .withRetryDelayMs(RETRY_DELAY_MILLIS) - .build()); + httpHubClient = new HttpHubClient(webClient, new HttpRetryConfig(MAX_RETRIES, RETRY_DELAY_MILLIS)); } @AfterEach diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java deleted file mode 100644 index b5c620f..0000000 --- a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java +++ /dev/null @@ -1,111 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; -import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.message.BasicNameValuePair; -import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.net.URLEncodedUtils; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.springframework.web.reactive.function.client.WebClient; -import org.testcontainers.shaded.com.fasterxml.jackson.core.JsonProcessingException; -import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; -import reactor.test.StepVerifier; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Map; - -import static com.mongodb.assertions.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class HttpHubAuthClientIT { - - private MockWebServer mockWebServer; - - private HubAuthClient httpHubAuthClient; - - private static final int MAX_RETRIES = 5; - private static final int RETRY_DELAY_MILLIS = 10; // keeping it short for testing purposes! - - @BeforeEach - void setUp() { - mockWebServer = new MockWebServer(); - var webClient = WebClient.create(mockWebServer.url("/").toString()); - httpHubAuthClient = new HttpHubAuthClient(webClient, new HttpHubAuthClientConfig.Builder() - .withMaxRetries(MAX_RETRIES) - .withRetryDelayMs(RETRY_DELAY_MILLIS) - .build()); - } - - @AfterEach - void tearDown() throws IOException { - mockWebServer.shutdown(); - } - - @Test - void obtainingAccessTokenSucceedsOnFirstTry() throws JsonProcessingException, InterruptedException { - mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) - .setHeader("Content-Type", "application/json") - .setBody( - new ObjectMapper().writeValueAsString(Map.of("access_token", "test")) - )); - - StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret")) - .expectNext("test") - .verifyComplete(); - - assertEquals(1, mockWebServer.getRequestCount()); - var recordedRequest = mockWebServer.takeRequest(); - assertEquals("/token", recordedRequest.getPath()); - - var urlEncodedParams = URLEncodedUtils.parse(recordedRequest.getBody().readUtf8(), StandardCharsets.UTF_8); - - assertEquals(3, urlEncodedParams.size()); - assertEquals("robot_credentials", urlEncodedParams.getFirst().getValue()); - - assertTrue(urlEncodedParams.containsAll(List.of( - new BasicNameValuePair("grant_type", "robot_credentials"), - new BasicNameValuePair("id", "some-client-id"), - new BasicNameValuePair("secret", "some-client-secret") - ))); - } - - @Test - void obtainingAccessTokenSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException { - mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); - mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) - .setHeader("Content-Type", "application/json") - .setBody( - new ObjectMapper().writeValueAsString(Map.of("access_token", "test")) - )); - - StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret")) - .expectNext("test") - .verifyComplete(); - - assertEquals(2, mockWebServer.getRequestCount()); - - for (int i = 0; i < 2; i++) { - var recordedRequest = mockWebServer.takeRequest(); - assertEquals("/token", recordedRequest.getPath()); - } - } - - @Test - void obtainingAccessTokenFailsOfMaximumNumberOfRetriesOnServerError() throws InterruptedException { - for (int i = 0; i < (MAX_RETRIES + 1); i++) { - mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); - } - - StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret")) - .verifyError(HubAccessTokenNotObtainable.class); - - for (int i = 0; i < (MAX_RETRIES + 1); i++) { - var recordedRequest = mockWebServer.takeRequest(); - assertEquals("/token", recordedRequest.getPath()); - } - } -} diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java new file mode 100644 index 0000000..4f43adf --- /dev/null +++ b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java @@ -0,0 +1,457 @@ +package de.privateaim.node_message_broker.common.hub.auth; + +import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; +import de.privateaim.node_message_broker.common.HttpRetryConfig; +import de.privateaim.node_message_broker.common.OIDCAuthenticator; +import de.privateaim.node_message_broker.common.OIDCTokenPair; +import io.jsonwebtoken.Jwts; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.springframework.security.oauth2.core.OAuth2RefreshToken; +import org.springframework.security.oauth2.core.OAuth2Token; +import org.springframework.web.reactive.function.client.WebClient; +import org.testcontainers.shaded.com.fasterxml.jackson.core.JsonProcessingException; +import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper; +import reactor.test.StepVerifier; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Date; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class HubOIDCAuthenticatorIT { + + private MockWebServer mockWebServer; + private OIDCAuthenticator authenticator; + + private static final int MAX_RETRIES = 5; + private static final int RETRY_DELAY_MILLIS = 10; // keeping it short for testing purposes! + + private static final String CLIENT_ID = "some-client-id"; + private static final String CLIENT_SECRET = "some-client-secret"; + + @BeforeEach + void setUp() { + mockWebServer = new MockWebServer(); + authenticator = HubOIDCAuthenticator.builder() + .usingWebClient(WebClient.create(mockWebServer.url("/").toString())) + .withRetryConfig(new HttpRetryConfig(MAX_RETRIES, RETRY_DELAY_MILLIS)) + .withJsonDecoder(new com.fasterxml.jackson.databind.ObjectMapper()) + .withAuthCredentials(CLIENT_ID, CLIENT_SECRET) + .build(); + } + + @AfterEach + void tearDown() throws IOException { + mockWebServer.shutdown(); + } + + @Nested + class AuthenticationTests { + + @Test + void succeedsEvenIfServerDoesNotSendRefreshToken() throws JsonProcessingException { + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of("access_token", accessToken)) + )); + + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.authenticate()) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isEmpty()); + } + + @Test + void failsIfServerDoesNotSendAccessToken() throws JsonProcessingException { + var refreshToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of("refresh_token", refreshToken)) + )); + + StepVerifier.create(authenticator.authenticate()) + .expectError() + .verify(); + } + + @Test + void failsIfAccessTokenIsNotAValidJWT() throws JsonProcessingException { + var tokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", "test", + "refresh_token", tokenValue)) + )); + + StepVerifier.create(authenticator.authenticate()) + .expectError(HubAuthMalformedResponseException.class) + .verify(); + } + + @Test + void failsIfRefreshTokenIsNotAValidJWT() throws JsonProcessingException { + var tokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", tokenValue, + "refresh_token", "test_refresh")) + )); + + StepVerifier.create(authenticator.authenticate()) + .expectError(HubAuthMalformedResponseException.class) + .verify(); + } + + @Test + void authenticateSucceedsOnFirstTry() throws JsonProcessingException { + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + var refreshToken = Jwts.builder() + .subject("test_refresh") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", accessToken, + "refresh_token", refreshToken + )))); + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.authenticate()) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent()); + assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue()); + } + + @Test + void authenticateSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException { + if (MAX_RETRIES < 2) { + fail("misconfigured test environment"); + } + + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + var refreshToken = Jwts.builder() + .subject("test_refresh") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", accessToken, + "refresh_token", refreshToken + )) + )); + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.authenticate()) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(2, mockWebServer.getRequestCount()); + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent()); + assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue()); + + for (int i = 0; i < 2; i++) { + var recordedRequest = mockWebServer.takeRequest(); + assertEquals("/token", recordedRequest.getPath()); + } + } + + @Test + void authenticateFinallyFailsAfterExhaustedRetryCount() throws InterruptedException { + for (int i = 0; i < (MAX_RETRIES + 1); i++) { + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); + } + + StepVerifier.create(authenticator.authenticate()) + .verifyError(HubAccessTokenNotObtainable.class); + + for (int i = 0; i < (MAX_RETRIES + 1); i++) { + var recordedRequest = mockWebServer.takeRequest(); + assertEquals("/token", recordedRequest.getPath()); + } + } + + // TODO: add tests for checking what happens if issuedAt and expiresAt are missing... + } + + @Nested + class RefreshTests { + private OAuth2Token getSimpleRefreshToken() { + var issuedAt = Instant.now(); + var expiresAt = Instant.now().plus(Duration.ofDays(5)); + + var refreshTokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(issuedAt)) + .expiration(Date.from(expiresAt)) + .compact(); + + return new OAuth2RefreshToken( + refreshTokenValue, + issuedAt, + expiresAt + ); + } + + @Test + void succeedsEvenIfServerDoesNotSendRefreshToken() throws JsonProcessingException { + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of("access_token", accessToken)) + )); + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.authenticate()) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isEmpty()); + } + + @Test + void failsIfServerDoesNotSendAccessToken() throws JsonProcessingException { + var refreshToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of("refresh_token", refreshToken)) + )); + + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .expectError() + .verify(); + } + + @Test + void failsIfAccessTokenIsNotAValidJWT() throws JsonProcessingException { + var tokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", "test", + "refresh_token", tokenValue)) + )); + + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .expectError(HubAuthMalformedResponseException.class) + .verify(); + } + + @Test + void failsIfRefreshTokenIsNotAValidJWT() throws JsonProcessingException { + var tokenValue = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", tokenValue, + "refresh_token", "test_refresh")) + )); + + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .expectError(HubAuthMalformedResponseException.class) + .verify(); + } + + @Test + void refreshSucceedsOnFirstTry() throws JsonProcessingException { + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + var refreshToken = Jwts.builder() + .subject("test_refresh") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", accessToken, + "refresh_token", refreshToken + )))); + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent()); + assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue()); + } + + @Test + void refreshSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException { + if (MAX_RETRIES < 2) { + fail("misconfigured test environment"); + } + + var accessToken = Jwts.builder() + .subject("test") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + var refreshToken = Jwts.builder() + .subject("test_refresh") + .issuedAt(Date.from(Instant.now())) + .expiration(Date.from(Instant.now().plus(Duration.ofDays(5)))) + .compact(); + + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK) + .setHeader("Content-Type", "application/json") + .setBody( + new ObjectMapper().writeValueAsString(Map.of( + "access_token", accessToken, + "refresh_token", refreshToken + )) + )); + + var oidcTokenPair = new ArrayList(); + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .recordWith(ArrayList::new) + .expectNextCount(1) + .consumeRecordedWith(oidcTokenPair::addAll) + .expectComplete() + .verify(); + + assertEquals(2, mockWebServer.getRequestCount()); + assertEquals(1, oidcTokenPair.size()); + assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue()); + assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent()); + assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue()); + + for (int i = 0; i < 2; i++) { + var recordedRequest = mockWebServer.takeRequest(); + assertEquals("/token", recordedRequest.getPath()); + } + } + + @Test + void refreshFinallyFailsAfterExhaustedRetryCount() throws InterruptedException { + for (int i = 0; i < (MAX_RETRIES + 1); i++) { + mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE)); + } + + StepVerifier.create(authenticator.refresh(getSimpleRefreshToken())) + .verifyError(HubAccessTokenNotObtainable.class); + + for (int i = 0; i < (MAX_RETRIES + 1); i++) { + var recordedRequest = mockWebServer.takeRequest(); + assertEquals("/token", recordedRequest.getPath()); + } + } + } +} diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java deleted file mode 100644 index 9f35113..0000000 --- a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java +++ /dev/null @@ -1,106 +0,0 @@ -package de.privateaim.node_message_broker.common.hub.auth; - -import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.Mockito; -import org.mockito.junit.jupiter.MockitoExtension; -import org.springframework.http.HttpMethod; -import org.springframework.http.HttpStatusCode; -import org.springframework.web.reactive.function.client.ClientRequest; -import org.springframework.web.reactive.function.client.ClientResponse; -import org.springframework.web.reactive.function.client.ExchangeFunction; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - -import java.net.URI; - -import static org.mockito.Mockito.*; - -@ExtendWith(MockitoExtension.class) -public class RenewAuthTokenFilterTest { - - private static final String ROBOT_ID = "test-id"; - private static final String ROBOT_SECRET = "test-secret"; - - @Mock - private HubAuthClient hubAuthClientMock; - private RenewAuthTokenFilter renewAuthTokenFilter; - - @BeforeEach - void setUp() { - renewAuthTokenFilter = new RenewAuthTokenFilter(hubAuthClientMock, ROBOT_ID, ROBOT_SECRET); - } - - @AfterEach - void tearDown() { - Mockito.reset(hubAuthClientMock); - } - - @Test - void requestDoesNotGetReAuthenticatedOnSuccess() { - var targetResponse = Mockito.mock(ClientResponse.class); - doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode(); - var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build(); - - ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse); - - var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction); - - StepVerifier.create(actualResponse) - .expectNext(targetResponse) - .expectComplete() - .verify(); - } - - @Test - void requestGetsReAuthenticatedOnFailure() { - doReturn(Mono.just("some-access-token")).when(hubAuthClientMock).requestAccessToken(ROBOT_ID, ROBOT_SECRET); - var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED)) - .body("unauthorized").build(); - var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build(); - - var reAuthenticatedTargetResponse = Mockito.mock(ClientResponse.class); - var targetExchangeFunction = Mockito.mock(ExchangeFunction.class); - doReturn(Mono.just(unauthorizedTargetResponse), Mono.just(reAuthenticatedTargetResponse)) - .when(targetExchangeFunction) - .exchange(Mockito.any(ClientRequest.class)); - - var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction); - - StepVerifier.create(actualResponse) - .expectNext(reAuthenticatedTargetResponse) - .expectComplete() - .verify(); - - verify(hubAuthClientMock, times(1)) - .requestAccessToken(ROBOT_ID, ROBOT_SECRET); - } - - @Test - void requestUltimatelyFailsOnReAuthenticatedRequestFailure() { - doReturn(Mono.error(new Exception("some-error"))).when(hubAuthClientMock) - .requestAccessToken(ROBOT_ID, ROBOT_SECRET); - - var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED)) - .body("unauthorized").build(); - var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build(); - - var targetExchangeFunction = Mockito.mock(ExchangeFunction.class); - doReturn(Mono.just(unauthorizedTargetResponse)) - .when(targetExchangeFunction) - .exchange(Mockito.any(ClientRequest.class)); - - var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction); - - StepVerifier.create(actualResponse) - .expectError(Exception.class) - .verify(); - - verify(hubAuthClientMock, times(1)) - .requestAccessToken(ROBOT_ID, ROBOT_SECRET); - } -} diff --git a/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java b/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java index 8d72c37..0290ad2 100644 --- a/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java +++ b/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java @@ -1,8 +1,8 @@ package de.privateaim.node_message_broker.discovery; import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; +import de.privateaim.node_message_broker.common.HttpRetryConfig; import de.privateaim.node_message_broker.common.hub.HttpHubClient; -import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig; import de.privateaim.node_message_broker.common.hub.api.AnalysisNode; import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer; import de.privateaim.node_message_broker.common.hub.api.Node; @@ -39,10 +39,7 @@ public final class DiscoveryServiceIT { void setUp() { mockWebServer = new MockWebServer(); var noAuthWebClient = WebClient.create(mockWebServer.url("/").toString()); - var hubClientCfg = new HttpHubClientConfig.Builder() - .withMaxRetries(0) - .withRetryDelayMs(0) - .build(); + var hubClientCfg = new HttpRetryConfig(0, 0); var hubClient = Mockito.spy(new HttpHubClient(noAuthWebClient, hubClientCfg)); discoveryService = new DiscoveryService(hubClient, SELF_ROBOT_ID); } diff --git a/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java b/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java index a5a08c9..caedb6a 100644 --- a/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java +++ b/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java @@ -4,8 +4,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.JsonNodeFactory; import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus; +import de.privateaim.node_message_broker.common.HttpRetryConfig; import de.privateaim.node_message_broker.common.hub.HttpHubClient; -import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig; import de.privateaim.node_message_broker.common.hub.api.AnalysisNode; import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer; import de.privateaim.node_message_broker.common.hub.api.Node; @@ -57,10 +57,7 @@ void setUp() { mockSocket = Mockito.mock(Socket.class); spyMessageEmitter = Mockito.spy(new HubMessageEmitter(mockSocket)); var webClient = WebClient.create(mockWebServer.url("/").toString()); - var httpHubClient = new HttpHubClient(webClient, new HttpHubClientConfig.Builder() - .withMaxRetries(0) - .withRetryDelayMs(0) - .build()); + var httpHubClient = new HttpHubClient(webClient, new HttpRetryConfig(0, 0)); messageService = new MessageService(spyMessageEmitter, httpHubClient, SELF_ROBOT_ID); emitMessageCaptor = ArgumentCaptor.forClass(EmitMessage.class);