diff --git a/pom.xml b/pom.xml
index 4e5977f..1fda826 100644
--- a/pom.xml
+++ b/pom.xml
@@ -166,6 +166,27 @@
4.12.0
test
+
+
+ io.jsonwebtoken
+ jjwt-api
+ 0.12.6
+ test
+
+
+
+ io.jsonwebtoken
+ jjwt-impl
+ 0.12.6
+ test
+
+
+
+ io.jsonwebtoken
+ jjwt-jackson
+ 0.12.6
+ test
+
diff --git a/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java b/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java
index 60f372c..4e231d8 100644
--- a/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java
+++ b/src/main/java/de/privateaim/node_message_broker/common/CommonSpringConfig.java
@@ -1,13 +1,11 @@
package de.privateaim.node_message_broker.common;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
import de.privateaim.node_message_broker.ConfigurationUtil;
import de.privateaim.node_message_broker.common.hub.HttpHubClient;
-import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig;
import de.privateaim.node_message_broker.common.hub.HubClient;
-import de.privateaim.node_message_broker.common.hub.auth.HttpHubAuthClient;
-import de.privateaim.node_message_broker.common.hub.auth.HttpHubAuthClientConfig;
-import de.privateaim.node_message_broker.common.hub.auth.HubAuthClient;
-import de.privateaim.node_message_broker.common.hub.auth.RenewAuthTokenFilter;
+import de.privateaim.node_message_broker.common.hub.auth.HubOIDCAuthenticator;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
@@ -24,6 +22,9 @@
@Configuration
public class CommonSpringConfig {
+ private static final int EXCHANGE__MAX_RETRIES = 5;
+ private static final int EXCHANGE__MAX_RETRY_DELAY_MS = 1000;
+
@Value("${app.hub.baseUrl}")
private String hubCoreBaseUrl;
@@ -48,10 +49,16 @@ public String hubAuthRobotId() {
return hubAuthRobotId;
}
+ @Qualifier("HUB_EXCHANGE_RETRY_CONFIG")
+ @Bean
+ HttpRetryConfig exchangeRetryConfig() {
+ return new HttpRetryConfig(EXCHANGE__MAX_RETRIES, EXCHANGE__MAX_RETRY_DELAY_MS);
+ }
+
@Qualifier("HUB_CORE_WEB_CLIENT")
@Bean
public WebClient alwaysReAuthenticatedWebClient(
- @Qualifier("HUB_AUTH_RENEW_TOKEN") ExchangeFilterFunction renewTokenFilter,
+ @Qualifier("HUB_AUTHENTICATION_MIDDLEWARE") ExchangeFilterFunction authenticationMiddleware,
@Qualifier("BASE_SSL_HTTP_CLIENT_CONNECTOR") ReactorClientHttpConnector baseSslHttpClientConnector) {
// We can't use Spring's default security mechanisms out-of-the-box here since HUB uses a non-standard grant
// type which is not supported. There's a way by using a custom grant type accompanied by a client manager.
@@ -65,19 +72,17 @@ public WebClient alwaysReAuthenticatedWebClient(
return WebClient.builder()
.uriBuilderFactory(factory)
.defaultHeaders(httpHeaders -> httpHeaders.setAccept(List.of(MediaType.APPLICATION_JSON)))
- .filter(renewTokenFilter)
+ .filter(authenticationMiddleware)
.clientConnector(baseSslHttpClientConnector)
.build();
}
@Bean
- public HubClient hubClient(@Qualifier("HUB_CORE_WEB_CLIENT") WebClient alwaysReAuthenticatedWebClient) {
- var clientConfig = new HttpHubClientConfig.Builder()
- .withMaxRetries(5)
- .withRetryDelayMs(1000)
- .build();
-
- return new HttpHubClient(alwaysReAuthenticatedWebClient, clientConfig);
+ public HubClient hubClient(
+ @Qualifier("HUB_CORE_WEB_CLIENT") WebClient webClient,
+ @Qualifier("HUB_EXCHANGE_RETRY_CONFIG") HttpRetryConfig retryConfig
+ ) {
+ return new HttpHubClient(webClient, retryConfig);
}
@Qualifier("HUB_AUTH_WEB_CLIENT")
@@ -91,21 +96,37 @@ WebClient hubAuthWebClient(
.build();
}
- @Qualifier("HUB_AUTH_CLIENT")
+ @Qualifier("HUB_JSON_MAPPER")
+ @Bean
+ ObjectMapper simpleJsonMapper() {
+ return new ObjectMapper()
+ .findAndRegisterModules()
+ .disable(DeserializationFeature.FAIL_ON_IGNORED_PROPERTIES)
+ .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
+ }
+
+ @Qualifier("HUB_AUTHENTICATOR")
@Bean
- public HubAuthClient hubAuthClient(@Qualifier("HUB_AUTH_WEB_CLIENT") WebClient webClient) {
- var clientConfig = new HttpHubAuthClientConfig.Builder()
- .withMaxRetries(5)
- .withRetryDelayMs(1000)
+ OIDCAuthenticator hubAuthenticator(
+ @Qualifier("HUB_AUTH_WEB_CLIENT") WebClient webClient,
+ @Qualifier("HUB_EXCHANGE_RETRY_CONFIG") HttpRetryConfig retryConfig,
+ @Qualifier("HUB_AUTH_ROBOT_ID") String hubAuthRobotId,
+ @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret,
+ @Qualifier("HUB_JSON_MAPPER") ObjectMapper jsonMapper
+ ) {
+ return HubOIDCAuthenticator.builder()
+ .usingWebClient(webClient)
+ .withRetryConfig(retryConfig)
+ .withAuthCredentials(hubAuthRobotId, hubAuthRobotSecret)
+ .withJsonDecoder(jsonMapper)
.build();
- return new HttpHubAuthClient(webClient, clientConfig);
}
- @Qualifier("HUB_AUTH_RENEW_TOKEN")
+ @Qualifier("HUB_AUTHENTICATION_MIDDLEWARE")
@Bean
- ExchangeFilterFunction renewAuthTokenFilter(
- @Qualifier("HUB_AUTH_CLIENT") HubAuthClient hubAuthClient,
- @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret) {
- return new RenewAuthTokenFilter(hubAuthClient, hubAuthRobotId, hubAuthRobotSecret);
+ ExchangeFilterFunction hubAuthenticationMiddleware(
+ @Qualifier("HUB_AUTHENTICATOR") OIDCAuthenticator authenticator
+ ) {
+ return new OIDCAuthenticatorMiddleware(authenticator);
}
}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java b/src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java
similarity index 73%
rename from src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java
rename to src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java
index 1ba97a8..e70a523 100644
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClientConfig.java
+++ b/src/main/java/de/privateaim/node_message_broker/common/HttpRetryConfig.java
@@ -1,12 +1,12 @@
-package de.privateaim.node_message_broker.common.hub;
+package de.privateaim.node_message_broker.common;
/**
- * Additional behavioural configuration for the {@link HttpHubClient}.
+ * Configuration options for retrying failed HTTP requests.
*
* @param maxRetries number of maximum retries carried out by the client in case of a retryable error
* @param retryDelayMs time between retries in ms
*/
-public record HttpHubClientConfig(int maxRetries, int retryDelayMs) {
+public record HttpRetryConfig(int maxRetries, int retryDelayMs) {
public static final class Builder {
private int maxRetries = 5;
private int retryDelayMs = 2000;
@@ -21,7 +21,7 @@ public Builder withRetryDelayMs(int retryDelayMs) {
return this;
}
- public HttpHubClientConfig build() {
+ public HttpRetryConfig build() {
if (maxRetries < 0) {
throw new IllegalArgumentException("maxRetries must be greater than 0");
}
@@ -30,7 +30,7 @@ public HttpHubClientConfig build() {
throw new IllegalArgumentException("retryDelayMs must be greater than 0");
}
- return new HttpHubClientConfig(maxRetries, retryDelayMs);
+ return new HttpRetryConfig(maxRetries, retryDelayMs);
}
}
}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java
new file mode 100644
index 0000000..d7f600b
--- /dev/null
+++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticator.java
@@ -0,0 +1,24 @@
+package de.privateaim.node_message_broker.common;
+
+import org.springframework.security.oauth2.core.OAuth2Token;
+import reactor.core.publisher.Mono;
+
+/**
+ * Describes an OIDC compliant authenticator.
+ */
+public interface OIDCAuthenticator {
+ /**
+ * Authenticates against an external system.
+ *
+ * @return A pair of access and refresh token.
+ */
+ Mono authenticate();
+
+ /**
+ * Refreshes the authentication against an external system.
+ *
+ * @param refreshToken The refresh token to be used.
+ * @return A pair of access and refresh token.
+ */
+ Mono refresh(OAuth2Token refreshToken);
+}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java
new file mode 100644
index 0000000..88ea2f1
--- /dev/null
+++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddleware.java
@@ -0,0 +1,91 @@
+package de.privateaim.node_message_broker.common;
+
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import reactor.core.publisher.Mono;
+
+import java.time.Instant;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * A middleware for authenticating against several different services of a single external provider.
+ */
+@Slf4j
+public final class OIDCAuthenticatorMiddleware implements ExchangeFilterFunction {
+
+ private final Map tokenPairByHost;
+ private final OIDCAuthenticator oidcAuthenticator;
+
+ public OIDCAuthenticatorMiddleware(OIDCAuthenticator oidcAuthenticator) {
+ this.tokenPairByHost = new ConcurrentHashMap<>();
+ this.oidcAuthenticator = requireNonNull(oidcAuthenticator, "OIDC authenticator must not be null");
+ }
+
+ @Override
+ public @NonNull Mono filter(@NonNull ClientRequest request, ExchangeFunction next) {
+ var host = request.url().getHost();
+ if (host == null) {
+ return Mono.just(ClientResponse.create(HttpStatus.BAD_REQUEST).build());
+ }
+
+ var authToken = computeAccessTokenForHost(host);
+ var authenticatedRequest = ClientRequest.from(request)
+ .headers(headers -> headers.setBearerAuth(authToken.getTokenValue()))
+ .build();
+
+ return next.exchange(authenticatedRequest).flatMap(response -> {
+ // Handling for unauthorized events in case of time overlaps regarding token expiration.
+ // Can happen if time skew is not properly handled by the authentication server.
+ if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) {
+ return response.releaseBody()
+ .then(Mono.just(computeAccessTokenForHost(host)))
+ .flatMap(token -> {
+ var newRequest = ClientRequest.from(request)
+ .headers(headers -> headers.setBearerAuth(token.getTokenValue()))
+ .build();
+ log.warn("retrying request to '{}' with new bearer token after receiving status code 401 " +
+ "(unauthorized)", request.url());
+ return next.exchange(newRequest);
+ });
+ } else {
+ return Mono.just(response);
+ }
+ });
+ }
+
+ private OAuth2AccessToken computeAccessTokenForHost(@NonNull String host) {
+ // TODO: revise - find another way to circumvent using block()
+ // The following is an atomic operation!
+ return tokenPairByHost.compute(host, (unused, tokenPair) -> {
+ if (tokenPair == null) {
+ log.info("acquiring access token for host '{}' as there is none yet", host);
+ return oidcAuthenticator.authenticate().block();
+ }
+
+ if (tokenPair.accessToken().getExpiresAt().isBefore(Instant.now())) {
+ return tokenPair.refreshToken()
+ .map(refreshToken -> {
+ if (refreshToken.getExpiresAt().isBefore(Instant.now())) {
+ log.warn("refresh token expired - acquiring new pair of access token and refresh token for " +
+ "host '{}'", host);
+ return oidcAuthenticator.authenticate().block();
+ } else {
+ log.info("refreshing access token for host '{}'", host);
+ return oidcAuthenticator.refresh(refreshToken).block();
+ }
+ })
+ .orElseGet(() -> oidcAuthenticator.authenticate().block());
+ }
+ return tokenPair;
+ }).accessToken();
+ }
+}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java b/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java
new file mode 100644
index 0000000..0876394
--- /dev/null
+++ b/src/main/java/de/privateaim/node_message_broker/common/OIDCTokenPair.java
@@ -0,0 +1,18 @@
+package de.privateaim.node_message_broker.common;
+
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+
+import java.util.Optional;
+
+/**
+ * An OIDC compliant pair of access token & refresh token.
+ *
+ * @param accessToken JWT acting as an access token.
+ * @param refreshToken JWT acting as a refresh token for acquiring a new access token.
+ */
+public record OIDCTokenPair(
+ OAuth2AccessToken accessToken,
+ Optional refreshToken
+) {
+}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java
index 06002b7..26f7766 100644
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java
+++ b/src/main/java/de/privateaim/node_message_broker/common/hub/HttpHubClient.java
@@ -1,5 +1,6 @@
package de.privateaim.node_message_broker.common.hub;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
import de.privateaim.node_message_broker.common.hub.api.AnalysisNode;
import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer;
import de.privateaim.node_message_broker.common.hub.api.Node;
@@ -30,11 +31,11 @@
public final class HttpHubClient implements HubClient {
private final WebClient authenticatedWebClient;
- private final HttpHubClientConfig config;
+ private final HttpRetryConfig retryConfig;
- public HttpHubClient(WebClient authenticatedWebClient, HttpHubClientConfig config) {
+ public HttpHubClient(WebClient authenticatedWebClient, HttpRetryConfig retryConfig) {
this.authenticatedWebClient = requireNonNull(authenticatedWebClient, "authenticated web client must not be null");
- this.config = requireNonNull(config, "config must not be null");
+ this.retryConfig = requireNonNull(retryConfig, "retry config must not be null");
}
// TODO: this might use a cache to cut corners and improve performance by avoiding unnecessary round-trips
@@ -59,12 +60,12 @@ public Mono> fetchAnalysisNodes(String analysisId) {
.bodyToMono(new ParameterizedTypeReference>>() {
})
.map(resp -> resp.data)
- .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs()))
+ .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs()))
.jitter(0.75)
.filter(err -> err instanceof HubCoreServerException)
.onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) ->
new HubAnalysisNodesNotObtainable("exhausted maximum number of retries of '%d'"
- .formatted(config.maxRetries())))));
+ .formatted(retryConfig.maxRetries())))));
}
// TODO: add cache here! - see spring annotations
@@ -110,11 +111,11 @@ public Mono fetchPublicKey(String robotId) {
"robot id `%s`".formatted(robotId), e));
}
})
- .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs()))
+ .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs()))
.jitter(0.75)
.filter(err -> err instanceof HubCoreServerException)
.onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) ->
new HubNodePublicKeyNotObtainable("exhausted maximum number of retries of '%d'"
- .formatted(config.maxRetries())))));
+ .formatted(retryConfig.maxRetries())))));
}
}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java
deleted file mode 100644
index e5d0ebd..0000000
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClient.java
+++ /dev/null
@@ -1,59 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-import de.privateaim.node_message_broker.common.hub.auth.api.HubAuthTokenResponse;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.http.HttpStatusCode;
-import org.springframework.http.MediaType;
-import org.springframework.web.reactive.function.BodyInserters;
-import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.publisher.Mono;
-import reactor.util.retry.Retry;
-
-import java.time.Duration;
-
-import static java.util.Objects.requireNonNull;
-
-/**
- * A client for communicating with the hub's auth service via HTTP/HTTPS.
- */
-@Slf4j
-public final class HttpHubAuthClient implements HubAuthClient {
-
- private final WebClient webClient;
- private final HttpHubAuthClientConfig config;
-
- private static final String TOKEN_PATH = "/token";
- private static final String GRANT_TYPE = "robot_credentials";
-
- public HttpHubAuthClient(WebClient webClient, HttpHubAuthClientConfig config) {
- this.webClient = requireNonNull(webClient, "web client must not be null");
- this.config = requireNonNull(config, "config must not be null");
- }
-
- @Override
- public Mono requestAccessToken(String clientId, String clientSecret) {
- return webClient.post()
- .uri(TOKEN_PATH)
- .contentType(MediaType.APPLICATION_FORM_URLENCODED)
- .accept(MediaType.APPLICATION_JSON)
- .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE)
- .with("id", clientId)
- .with("secret", clientSecret))
- .retrieve()
- .onStatus(HttpStatusCode::is5xxServerError,
- response -> {
- var err = new HubAuthException("could not fetch hub access token");
-
- log.warn("retrying token request after failed attempt", err);
- return Mono.error(err);
- })
- .bodyToMono(HubAuthTokenResponse.class)
- .map(hubAuthTokenResponse -> hubAuthTokenResponse.accessToken)
- .retryWhen(Retry.backoff(config.maxRetries(), Duration.ofMillis(config.retryDelayMs()))
- .jitter(0.75)
- .filter(err -> err instanceof HubAuthException)
- .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) ->
- new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'"
- .formatted(config.maxRetries())))));
- }
-}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java
deleted file mode 100644
index 5820594..0000000
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientConfig.java
+++ /dev/null
@@ -1,37 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-/**
- * Additional behavioural configuration for the {@link HttpHubAuthClient}.
- *
- * @param maxRetries number of maximum retries carried out by the client in case of a retryable error
- * @param retryDelayMs time between retries in ms
- */
-public record HttpHubAuthClientConfig(int maxRetries, int retryDelayMs) {
-
- public static final class Builder {
- private int maxRetries = 5;
- private int retryDelayMs = 2000;
-
- public Builder withMaxRetries(int maxRetries) {
- this.maxRetries = maxRetries;
- return this;
- }
-
- public Builder withRetryDelayMs(int retryDelayMs) {
- this.retryDelayMs = retryDelayMs;
- return this;
- }
-
- public HttpHubAuthClientConfig build() {
- if (maxRetries < 0) {
- throw new IllegalArgumentException("maxRetries must be greater than 0");
- }
-
- if (retryDelayMs < 0) {
- throw new IllegalArgumentException("retryDelayMs must be greater than 0");
- }
-
- return new HttpHubAuthClientConfig(maxRetries, retryDelayMs);
- }
- }
-}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java
deleted file mode 100644
index b117aeb..0000000
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthClient.java
+++ /dev/null
@@ -1,18 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-import reactor.core.publisher.Mono;
-
-/**
- * Describes a client able to carry out authentication related operations.
- */
-public interface HubAuthClient {
-
- /**
- * Asynchronously requests an access token from the hub using the "robot_credentials" grant type.
- *
- * @param clientId the client's id
- * @param clientSecret the client's secret
- * @return The access token.
- */
- Mono requestAccessToken(String clientId, String clientSecret);
-}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java
new file mode 100644
index 0000000..88cf846
--- /dev/null
+++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubAuthMalformedResponseException.java
@@ -0,0 +1,10 @@
+package de.privateaim.node_message_broker.common.hub.auth;
+
+/**
+ * Signals that the response from the Hub's auth service is malformed.
+ */
+public class HubAuthMalformedResponseException extends Exception {
+ public HubAuthMalformedResponseException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java
new file mode 100644
index 0000000..5b97603
--- /dev/null
+++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticator.java
@@ -0,0 +1,256 @@
+package de.privateaim.node_message_broker.common.hub.auth;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
+import de.privateaim.node_message_broker.common.OIDCAuthenticator;
+import de.privateaim.node_message_broker.common.OIDCTokenPair;
+import de.privateaim.node_message_broker.common.hub.auth.api.HubAuthTokenResponse;
+import lombok.Getter;
+import lombok.NonNull;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.http.HttpStatusCode;
+import org.springframework.http.MediaType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2Token;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.publisher.Mono;
+import reactor.util.retry.Retry;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Base64;
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * An OIDC compliant authenticator for authenticating against the Hub.
+ */
+@Slf4j
+public final class HubOIDCAuthenticator implements OIDCAuthenticator {
+
+ private final WebClient webClient;
+ private final HttpRetryConfig retryConfig;
+ private final ObjectMapper jsonDecoder;
+ private final String clientId;
+ private final String clientSecret;
+
+ private static final String TOKEN_PATH = "/token";
+ private static final String GRANT_TYPE_AUTHENTICATE = "robot_credentials";
+ private static final String GRANT_TYPE_REFRESH_TOKEN = "refresh_token";
+
+ private static final String JWT_CLAIM__ISSUED_AT = "iat";
+ private static final String JWT_CLAIM__EXPIRES_AT = "exp";
+
+ private HubOIDCAuthenticator(Builder builder) {
+ this.webClient = requireNonNull(builder.getWebClient(), "web client must not be null");
+ this.retryConfig = requireNonNull(builder.getRetryConfig(), "retry configuration must not be null");
+ this.jsonDecoder = requireNonNull(builder.getJsonDecoder(), "json decoder must not be null");
+ this.clientId = requireNonNull(builder.getClientId(), "client ID must not be null");
+ this.clientSecret = requireNonNull(builder.getClientSecret(), "client secret must not be null");
+ }
+
+ @Override
+ public Mono authenticate() {
+ return webClient.post()
+ .uri(TOKEN_PATH)
+ .contentType(MediaType.APPLICATION_FORM_URLENCODED)
+ .accept(MediaType.APPLICATION_JSON)
+ .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE_AUTHENTICATE)
+ .with("id", clientId)
+ .with("secret", clientSecret))
+ .retrieve()
+ .onStatus(HttpStatusCode::is5xxServerError,
+ response -> {
+ var err = new HubAuthException("could not fetch hub access token");
+
+ log.warn("retrying token request after failed attempt", err);
+ return Mono.error(err);
+ })
+ .bodyToMono(HubAuthTokenResponse.class)
+ .flatMap(hubAuthTokenResponse -> {
+ try {
+ return Mono.just(new OIDCTokenPair(
+ parseAccessToken(hubAuthTokenResponse.accessToken),
+ Optional.ofNullable(parseRefreshToken(hubAuthTokenResponse.refreshToken))));
+ } catch (IllegalArgumentException e) {
+ return Mono.error(
+ new HubAuthMalformedResponseException("received malformed response from HUB when " +
+ "trying to fetch access token", e));
+ }
+ })
+ .doOnError(err -> log.error(err.getMessage(), err))
+ .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs()))
+ .jitter(0.75)
+ .filter(err -> err instanceof HubAuthException)
+ .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) ->
+ new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'"
+ .formatted(retryConfig.maxRetries())))));
+ }
+
+ @Override
+ public Mono refresh(OAuth2Token refreshToken) {
+ return webClient.post()
+ .uri(TOKEN_PATH)
+ .contentType(MediaType.APPLICATION_FORM_URLENCODED)
+ .accept(MediaType.APPLICATION_JSON)
+ .body(BodyInserters.fromFormData("grant_type", GRANT_TYPE_REFRESH_TOKEN)
+ .with("refresh_token", refreshToken.getTokenValue()))
+ .retrieve()
+ .onStatus(HttpStatusCode::is5xxServerError,
+ response -> {
+ var err = new HubAuthException("could not fetch new access token using refresh token");
+
+ log.warn("retrying token refresh request after failed attempt", err);
+ return Mono.error(err);
+ })
+ .bodyToMono(HubAuthTokenResponse.class)
+ .flatMap(hubAuthTokenResponse -> {
+ try {
+ return Mono.just(new OIDCTokenPair(
+ parseAccessToken(hubAuthTokenResponse.accessToken),
+ Optional.ofNullable(parseRefreshToken(hubAuthTokenResponse.refreshToken))));
+ } catch (IllegalArgumentException e) {
+ return Mono.error(
+ new HubAuthMalformedResponseException("received malformed response from HUB when " +
+ "trying to refresh access token", e));
+ }
+ })
+ .doOnError(err -> log.error(err.getMessage(), err))
+ .retryWhen(Retry.backoff(retryConfig.maxRetries(), Duration.ofMillis(retryConfig.retryDelayMs()))
+ .jitter(0.75)
+ .filter(err -> err instanceof HubAuthException)
+ .onRetryExhaustedThrow(((retryBackoffSpec, retrySignal) ->
+ new HubAccessTokenNotObtainable("exhausted maximum retries of '%d'"
+ .formatted(retryConfig.maxRetries())))));
+ }
+
+ private OAuth2AccessToken parseAccessToken(@NonNull String tokenValue) {
+ try {
+ var jwt = decodeJwt(tokenValue);
+ if (jwt.getIssuedAt() == null) {
+ throw new IllegalArgumentException("JWT is missing issuing claim");
+ }
+ if (jwt.getExpiresAt() == null) {
+ throw new IllegalArgumentException("JWT is missing expiration claim");
+ }
+
+ return new OAuth2AccessToken(
+ OAuth2AccessToken.TokenType.BEARER,
+ jwt.getTokenValue(),
+ jwt.getIssuedAt(),
+ jwt.getExpiresAt());
+ } catch (JsonProcessingException e) {
+ throw new IllegalArgumentException("failed to decode access token from Hub");
+ }
+ }
+
+ private OAuth2RefreshToken parseRefreshToken(String tokenValue) {
+ if (tokenValue == null) {
+ return null;
+ }
+
+ try {
+ var jwt = decodeJwt(tokenValue);
+ if (jwt.getIssuedAt() == null) {
+ throw new IllegalArgumentException("JWT is missing issuing claim");
+ }
+ if (jwt.getExpiresAt() == null) {
+ throw new IllegalArgumentException("JWT is missing expiration claim");
+ }
+
+ return new OAuth2RefreshToken(
+ jwt.getTokenValue(),
+ jwt.getIssuedAt(),
+ jwt.getExpiresAt());
+ } catch (JsonProcessingException e) {
+ throw new IllegalArgumentException("failed to decode refresh token from Hub");
+ }
+ }
+
+ private OAuth2Token decodeJwt(@NonNull String tokenValue) throws JsonProcessingException {
+ var chunks = tokenValue.split("\\.");
+
+ if (chunks.length < 2) {
+ throw new IllegalArgumentException("could not decode JWT due to missing payload section");
+ }
+ var payload = new String(Base64.getUrlDecoder().decode(chunks[1]));
+
+ var payloadJson = jsonDecoder.readTree(payload);
+
+ var issuedAt = Optional.ofNullable(payloadJson.get(JWT_CLAIM__ISSUED_AT))
+ .map(JsonNode::asLong)
+ .map(Instant::ofEpochSecond)
+ .orElse(null);
+ var expiresAt = Optional.ofNullable(payloadJson.get(JWT_CLAIM__EXPIRES_AT))
+ .map(JsonNode::asLong)
+ .map(Instant::ofEpochSecond)
+ .orElse(null);
+
+ return new HubJWT(tokenValue, issuedAt, expiresAt);
+ }
+
+ private record HubJWT(String tokenValue, Instant issuedAt, Instant expiresAt) implements OAuth2Token {
+ @Override
+ public String getTokenValue() {
+ return tokenValue;
+ }
+
+ @Override
+ public Instant getIssuedAt() {
+ return issuedAt;
+ }
+
+ @Override
+ public Instant getExpiresAt() {
+ return expiresAt;
+ }
+ }
+
+
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ @Getter
+ public static class Builder {
+ private WebClient webClient;
+ private HttpRetryConfig retryConfig;
+ private ObjectMapper jsonDecoder;
+ private String clientId;
+ private String clientSecret;
+
+ Builder() {
+ }
+
+ public Builder usingWebClient(WebClient webClient) {
+ this.webClient = webClient;
+ return this;
+ }
+
+ public Builder withRetryConfig(HttpRetryConfig retryConfig) {
+ this.retryConfig = retryConfig;
+ return this;
+ }
+
+ public Builder withJsonDecoder(ObjectMapper jsonDecoder) {
+ this.jsonDecoder = jsonDecoder;
+ return this;
+ }
+
+ public Builder withAuthCredentials(String clientId, String clientSecret) {
+ this.clientId = clientId;
+ this.clientSecret = clientSecret;
+ return this;
+ }
+
+ public OIDCAuthenticator build() {
+ // TODO: add validation step
+ return new HubOIDCAuthenticator(this);
+ }
+ }
+}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java
deleted file mode 100644
index 7749960..0000000
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilter.java
+++ /dev/null
@@ -1,63 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-import lombok.NonNull;
-import lombok.extern.slf4j.Slf4j;
-import org.springframework.http.HttpStatus;
-import org.springframework.web.reactive.function.client.ClientRequest;
-import org.springframework.web.reactive.function.client.ClientResponse;
-import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
-import org.springframework.web.reactive.function.client.ExchangeFunction;
-import reactor.core.publisher.Mono;
-
-import static java.util.Objects.requireNonNull;
-
-/**
- * Filter for a {@link org.springframework.web.reactive.function.client.WebClient} which handles re-authentication of
- * the request if the request target responds with a 401 (Unauthorized) status code.
- */
-@Slf4j
-public class RenewAuthTokenFilter implements ExchangeFilterFunction {
-
- private final HubAuthClient authClient;
-
- private final String robotId;
-
- private final String robotSecret;
-
- /**
- * Creates a new instance of the filter with information required for re-authenticating any in-flight requests.
- *
- * @param authClient client able to communicate to the Hub's auth service
- * @param robotId client's robot id used for authentication
- * @param robotSecret client's robot secret used for authentication
- */
- public RenewAuthTokenFilter(HubAuthClient authClient, String robotId, String robotSecret) {
- this.authClient = requireNonNull(authClient, "auth client must not be null");
- this.robotId = robotId;
- this.robotSecret = robotSecret;
- }
-
- @Override
- public @NonNull Mono filter(@NonNull ClientRequest request, ExchangeFunction next) {
- return next.exchange(request).flatMap(response -> {
- if (response.statusCode().value() == HttpStatus.UNAUTHORIZED.value()) {
- return response.releaseBody()
- .then(acquireNewToken(robotId, robotSecret))
- .flatMap(token -> {
- var newRequest = ClientRequest.from(request)
- .headers(headers -> headers.setBearerAuth(token))
- .build();
- log.warn("retrying request to {} after receiving status code 401 (unauthorized)",
- request.url());
- return next.exchange(newRequest);
- });
- } else {
- return Mono.just(response);
- }
- });
- }
-
- private Mono acquireNewToken(String robotId, String robotSecret) {
- return authClient.requestAccessToken(robotId, robotSecret);
- }
-}
diff --git a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java
index 6cbb41a..7c5f0d6 100644
--- a/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java
+++ b/src/main/java/de/privateaim/node_message_broker/common/hub/auth/api/HubAuthTokenResponse.java
@@ -15,7 +15,10 @@
@JsonIgnoreProperties(ignoreUnknown = true)
public final class HubAuthTokenResponse {
- @JsonProperty("access_token")
+ @JsonProperty(value = "access_token", required = true)
public String accessToken;
+ @JsonProperty(value = "refresh_token")
+ public String refreshToken;
+
}
diff --git a/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java b/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java
index fd61ae7..9e7dc4a 100644
--- a/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java
+++ b/src/main/java/de/privateaim/node_message_broker/message/MessageSpringConfig.java
@@ -2,8 +2,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import de.privateaim.node_message_broker.ConfigurationUtil;
+import de.privateaim.node_message_broker.common.OIDCAuthenticator;
import de.privateaim.node_message_broker.common.hub.HubClient;
-import de.privateaim.node_message_broker.common.hub.auth.HubAuthClient;
import de.privateaim.node_message_broker.message.crypto.HubMessageCryptoService;
import de.privateaim.node_message_broker.message.crypto.MessageCryptoService;
import de.privateaim.node_message_broker.message.emit.EmitMessage;
@@ -76,10 +76,8 @@ OkHttpClient socketBaseClient(@Qualifier("COMMON_JAVA_SSL_CONTEXT") SSLContext s
@Qualifier("HUB_MESSENGER_UNDERLYING_SOCKET")
@Bean(destroyMethod = "disconnect")
public Socket underlyingMessengerSocket(
- @Qualifier("HUB_AUTH_CLIENT") HubAuthClient hubAuthClient,
+ @Qualifier("HUB_AUTHENTICATOR") OIDCAuthenticator hubAuthenticator,
@Qualifier("HUB_MESSAGE_RECEIVER") MessageReceiver messageReceiver,
- @Qualifier("HUB_AUTH_ROBOT_SECRET") String hubAuthRobotSecret,
- @Qualifier("HUB_AUTH_ROBOT_ID") String hubAuthRobotId,
@Qualifier("HUB_MESSENGER_UNDERLYING_SOCKET_SECURE_CLIENT") OkHttpClient secureBaseClient) {
IO.Options options = IO.Options.builder()
.setPath(null)
@@ -96,8 +94,12 @@ public Socket underlyingMessengerSocket(
log.error("cannot connect to hub messenger at `{}`", hubMessengerBaseUrl);
// we block here since this is a crucial component
- options.auth.put("token", hubAuthClient.requestAccessToken(hubAuthRobotId, hubAuthRobotSecret)
- .block());
+ var oidcTokenPair = hubAuthenticator.authenticate().block();
+ if (oidcTokenPair == null) {
+ throw new RuntimeException("authentication failed - cannot connect to hub messenger at `%s`"
+ .formatted(hubMessengerBaseUrl));
+ }
+ options.auth.put("token", oidcTokenPair.accessToken().getTokenValue());
log.info("reconnecting to hub messenger at `{}` with new authentication token", hubMessengerBaseUrl);
socket.connect();
@@ -110,8 +112,13 @@ public Socket underlyingMessengerSocket(
objects -> {
log.info("trying to reconnect to hub messenger via socket at `{}", hubMessengerBaseUrl);
// we block here since this is a crucial component
- options.auth.put("token", hubAuthClient.requestAccessToken(hubAuthRobotId, hubAuthRobotSecret)
- .block());
+ var oidcTokenPair = hubAuthenticator.authenticate().block();
+ if (oidcTokenPair == null) {
+ throw new RuntimeException("authentication failed - cannot connect to hub messenger at `%s`"
+ .formatted(hubMessengerBaseUrl));
+ }
+
+ options.auth.put("token", oidcTokenPair.accessToken().getTokenValue());
});
socket.on(SOCKET_RECEIVE_HUB_MESSAGE_IDENTIFIER, objects -> {
@@ -241,12 +248,6 @@ MessageSubscriptionService messageSubscriptionService(
return new MessageSubscriptionServiceImpl(messageSubscriptionRepository);
}
- @Qualifier("HUB_MESSAGE_RECEIVE_JSON_MAPPER")
- @Bean
- ObjectMapper hubMessageJsonMapper() {
- return new ObjectMapper();
- }
-
@Qualifier("HUB_MESSAGE_RECEIVE_MIDDLEWARE_DECRYPT")
@Bean
Function> hubMessageReceiveDecryptionMiddleware(
@@ -303,7 +304,7 @@ MessageConsumer hubMessageConsumer(
@Qualifier("HUB_MESSAGE_RECEIVER")
@Bean
MessageReceiver hubMessageReceiver(
- @Qualifier("HUB_MESSAGE_RECEIVE_JSON_MAPPER") ObjectMapper jsonMapper,
+ @Qualifier("HUB_JSON_MAPPER") ObjectMapper jsonMapper,
@Qualifier("HUB_MESSAGE_RECEIVE_MIDDLEWARES") List>> middlewares,
@Qualifier("HUB_MESSAGE_RECEIVE_CONSUMER") MessageConsumer messageConsumer
) {
diff --git a/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java b/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java
new file mode 100644
index 0000000..b205792
--- /dev/null
+++ b/src/test/java/de/privateaim/node_message_broker/common/OIDCAuthenticatorMiddlewareTest.java
@@ -0,0 +1,317 @@
+package de.privateaim.node_message_broker.common;
+
+import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
+import io.jsonwebtoken.Jwts;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatusCode;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFunction;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import java.net.URI;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.mockito.Mockito.*;
+
+@ExtendWith(MockitoExtension.class)
+public class OIDCAuthenticatorMiddlewareTest {
+
+ @Mock
+ private OIDCAuthenticator authenticator;
+
+ @InjectMocks
+ private OIDCAuthenticatorMiddleware authenticatorMiddleware;
+
+ private OAuth2RefreshToken getSimpleRefreshToken(Instant issuedAt, Instant expiresAt) {
+ var refreshTokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(issuedAt))
+ .expiration(Date.from(expiresAt))
+ .compact();
+
+ return new OAuth2RefreshToken(
+ refreshTokenValue,
+ issuedAt,
+ expiresAt
+ );
+ }
+
+ private OAuth2AccessToken getSimpleAccessToken(Instant issuedAt, Instant expiresAt) {
+ var accessTokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(issuedAt))
+ .expiration(Date.from(expiresAt))
+ .compact();
+
+ return new OAuth2AccessToken(
+ OAuth2AccessToken.TokenType.BEARER,
+ accessTokenValue,
+ issuedAt,
+ expiresAt
+ );
+ }
+
+ private String getBearerToken(ClientRequest request) {
+ var authorizationHeaderValue = request.headers().getFirst("Authorization");
+ if (authorizationHeaderValue.startsWith("Bearer ")) {
+ return authorizationHeaderValue.substring("Bearer ".length());
+ }
+
+ return null;
+ }
+
+ @Test
+ void requestFailsIfNoHostInformationIsPresent() {
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("/no-host/just-a-path")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ var responses = new ArrayList();
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(responses::addAll)
+ .verifyComplete();
+
+ assertEquals(1, responses.size());
+ assertEquals(HttpStatus.SC_BAD_REQUEST, responses.getFirst().statusCode().value());
+ }
+
+ @Test
+ void firstRequestForAHostRequiresAuthentication() {
+ var issuedAt = Instant.now();
+ var expiresAt = issuedAt.plus(Duration.ofHours(1));
+
+ var accessToken = getSimpleAccessToken(issuedAt, expiresAt);
+ var refreshToken = getSimpleRefreshToken(issuedAt, expiresAt);
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ doReturn(Mono.just(new OIDCTokenPair(accessToken, Optional.of(refreshToken))))
+ .when(authenticator).authenticate();
+ doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
+
+
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(1)).authenticate();
+ }
+
+ @Test
+ void existingAccessTokenGetsReusedIfNotExpired() {
+ var issuedAt = Instant.now();
+ var expiresAt = issuedAt.plus(Duration.ofHours(1));
+
+ var accessToken = getSimpleAccessToken(issuedAt, expiresAt);
+ var refreshToken = getSimpleRefreshToken(issuedAt, expiresAt);
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ doReturn(Mono.just(new OIDCTokenPair(accessToken, Optional.of(refreshToken))))
+ .when(authenticator).authenticate();
+ doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
+
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(1)).authenticate();
+ }
+
+ @Test
+ void newAccessTokenIsAcquiredIfExpiredUsingRefreshTokenIfPresent() {
+ var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)),
+ Instant.now().minus(Duration.ofHours(1)));
+ var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2)));
+
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken))))
+ .when(authenticator).authenticate();
+ doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken))))
+ .when(authenticator).refresh(refreshToken);
+ doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
+
+ // This in fact runs the first request with an expired token. However, we are not interested in that fact
+ // here. It's more about filling the cache for the next request to the same host.
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(1)).authenticate();
+ verify(authenticator, Mockito.times(1)).refresh(refreshToken);
+ }
+
+ @Test
+ void newAccessTokenIsAcquiredEvenIfRefreshTokenIsAbsent() {
+ var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)),
+ Instant.now().minus(Duration.ofHours(1)));
+ var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.empty())),
+ Mono.just(new OIDCTokenPair(newAccessToken, Optional.empty())))
+ .when(authenticator).authenticate();
+ doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
+
+ // This in fact runs the first request with an expired token. However, we are not interested in that fact
+ // here. It's more about filling the cache for the next request to the same host.
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(2)).authenticate();
+ }
+
+ @Test
+ void newAccessAndRefreshTokenPairAcquiredIfBothHaveExpired() {
+ var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(3)),
+ Instant.now().minus(Duration.ofHours(1)));
+ var expiredRefreshToken = getSimpleRefreshToken(Instant.now().minus(Duration.ofHours(3)),
+ Instant.now().minus(Duration.ofHours(2)));
+ var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2)));
+
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
+
+ when(authenticator.authenticate())
+ .thenReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(expiredRefreshToken))))
+ .thenReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken))));
+ doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
+
+ // This in fact runs the first request with an expired token. However, we are not interested in that fact
+ // here. It's more about filling the cache for the next request to the same host.
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+
+ verify(authenticator, Mockito.times(2)).authenticate();
+ verify(authenticator, Mockito.never()).refresh(Mockito.any(OAuth2RefreshToken.class));
+ }
+
+ @Test
+ void attemptsNewRequestWithNewAccessTokenIfServerRespondsWithUnauthorized() {
+ var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)),
+ Instant.now().minus(Duration.ofHours(1)));
+ var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2)));
+
+ var targetResponse = Mockito.mock(ClientResponse.class);
+ var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED))
+ .body("unauthorized").build();
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = Mockito.mock(ExchangeFunction.class);
+
+ doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken))))
+ .when(authenticator).authenticate();
+ doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken))))
+ .when(authenticator).refresh(refreshToken);
+
+ var requestCaptor = ArgumentCaptor.forClass(ClientRequest.class);
+ doReturn(Mono.just(unauthorizedTargetResponse), Mono.just(targetResponse))
+ .when(targetExchangeFunction)
+ .exchange(requestCaptor.capture());
+
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(targetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(1)).authenticate();
+ verify(authenticator, Mockito.times(1)).refresh(refreshToken);
+
+ var capturedRequests = requestCaptor.getAllValues();
+
+ assertEquals(2, capturedRequests.size());
+ assertEquals(expiredAccessToken.getTokenValue(), getBearerToken(capturedRequests.getFirst()));
+ assertEquals(newAccessToken.getTokenValue(), getBearerToken(capturedRequests.getLast()));
+ }
+
+ @Test
+ void requestEventuallyFailsIfRetryIsStillUnauthorized() {
+ var expiredAccessToken = getSimpleAccessToken(Instant.now().minus(Duration.ofHours(2)),
+ Instant.now().minus(Duration.ofHours(1)));
+ var refreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newAccessToken = getSimpleAccessToken(Instant.now(), Instant.now().plus(Duration.ofHours(1)));
+ var newRefreshToken = getSimpleRefreshToken(Instant.now(), Instant.now().plus(Duration.ofHours(2)));
+
+ var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED))
+ .body("unauthorized").build();
+ var request = ClientRequest.create(HttpMethod.GET, URI.create("https://test.host/some-resource")).build();
+ ExchangeFunction targetExchangeFunction = Mockito.mock(ExchangeFunction.class);
+
+ doReturn(Mono.just(new OIDCTokenPair(expiredAccessToken, Optional.of(refreshToken))))
+ .when(authenticator).authenticate();
+ doReturn(Mono.just(new OIDCTokenPair(newAccessToken, Optional.of(newRefreshToken))))
+ .when(authenticator).refresh(refreshToken);
+
+ var requestCaptor = ArgumentCaptor.forClass(ClientRequest.class);
+ doReturn(Mono.just(unauthorizedTargetResponse))
+ .when(targetExchangeFunction)
+ .exchange(requestCaptor.capture());
+
+ StepVerifier.create(authenticatorMiddleware.filter(request, targetExchangeFunction))
+ .expectNext(unauthorizedTargetResponse)
+ .expectComplete()
+ .verify();
+ verify(authenticator, Mockito.times(1)).authenticate();
+ verify(authenticator, Mockito.times(1)).refresh(refreshToken);
+
+ var capturedRequests = requestCaptor.getAllValues();
+
+ assertEquals(2, capturedRequests.size());
+ assertEquals(expiredAccessToken.getTokenValue(), getBearerToken(capturedRequests.getFirst()));
+ assertEquals(newAccessToken.getTokenValue(), getBearerToken(capturedRequests.getLast()));
+ }
+}
diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java
index cebb5c8..bd724e2 100644
--- a/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java
+++ b/src/test/java/de/privateaim/node_message_broker/common/hub/HttpHubClientIT.java
@@ -1,6 +1,7 @@
package de.privateaim.node_message_broker.common.hub;
import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
import de.privateaim.node_message_broker.common.hub.api.AnalysisNode;
import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer;
import de.privateaim.node_message_broker.common.hub.api.Node;
@@ -50,10 +51,7 @@ public class HttpHubClientIT {
void setUp() {
mockWebServer = new MockWebServer();
var webClient = WebClient.create(mockWebServer.url("/").toString());
- httpHubClient = new HttpHubClient(webClient, new HttpHubClientConfig.Builder()
- .withMaxRetries(MAX_RETRIES)
- .withRetryDelayMs(RETRY_DELAY_MILLIS)
- .build());
+ httpHubClient = new HttpHubClient(webClient, new HttpRetryConfig(MAX_RETRIES, RETRY_DELAY_MILLIS));
}
@AfterEach
diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java
deleted file mode 100644
index b5c620f..0000000
--- a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HttpHubAuthClientIT.java
+++ /dev/null
@@ -1,111 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
-import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.message.BasicNameValuePair;
-import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.net.URLEncodedUtils;
-import okhttp3.mockwebserver.MockResponse;
-import okhttp3.mockwebserver.MockWebServer;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.springframework.web.reactive.function.client.WebClient;
-import org.testcontainers.shaded.com.fasterxml.jackson.core.JsonProcessingException;
-import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper;
-import reactor.test.StepVerifier;
-
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-import java.util.List;
-import java.util.Map;
-
-import static com.mongodb.assertions.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-public class HttpHubAuthClientIT {
-
- private MockWebServer mockWebServer;
-
- private HubAuthClient httpHubAuthClient;
-
- private static final int MAX_RETRIES = 5;
- private static final int RETRY_DELAY_MILLIS = 10; // keeping it short for testing purposes!
-
- @BeforeEach
- void setUp() {
- mockWebServer = new MockWebServer();
- var webClient = WebClient.create(mockWebServer.url("/").toString());
- httpHubAuthClient = new HttpHubAuthClient(webClient, new HttpHubAuthClientConfig.Builder()
- .withMaxRetries(MAX_RETRIES)
- .withRetryDelayMs(RETRY_DELAY_MILLIS)
- .build());
- }
-
- @AfterEach
- void tearDown() throws IOException {
- mockWebServer.shutdown();
- }
-
- @Test
- void obtainingAccessTokenSucceedsOnFirstTry() throws JsonProcessingException, InterruptedException {
- mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
- .setHeader("Content-Type", "application/json")
- .setBody(
- new ObjectMapper().writeValueAsString(Map.of("access_token", "test"))
- ));
-
- StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret"))
- .expectNext("test")
- .verifyComplete();
-
- assertEquals(1, mockWebServer.getRequestCount());
- var recordedRequest = mockWebServer.takeRequest();
- assertEquals("/token", recordedRequest.getPath());
-
- var urlEncodedParams = URLEncodedUtils.parse(recordedRequest.getBody().readUtf8(), StandardCharsets.UTF_8);
-
- assertEquals(3, urlEncodedParams.size());
- assertEquals("robot_credentials", urlEncodedParams.getFirst().getValue());
-
- assertTrue(urlEncodedParams.containsAll(List.of(
- new BasicNameValuePair("grant_type", "robot_credentials"),
- new BasicNameValuePair("id", "some-client-id"),
- new BasicNameValuePair("secret", "some-client-secret")
- )));
- }
-
- @Test
- void obtainingAccessTokenSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException {
- mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
- mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
- .setHeader("Content-Type", "application/json")
- .setBody(
- new ObjectMapper().writeValueAsString(Map.of("access_token", "test"))
- ));
-
- StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret"))
- .expectNext("test")
- .verifyComplete();
-
- assertEquals(2, mockWebServer.getRequestCount());
-
- for (int i = 0; i < 2; i++) {
- var recordedRequest = mockWebServer.takeRequest();
- assertEquals("/token", recordedRequest.getPath());
- }
- }
-
- @Test
- void obtainingAccessTokenFailsOfMaximumNumberOfRetriesOnServerError() throws InterruptedException {
- for (int i = 0; i < (MAX_RETRIES + 1); i++) {
- mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
- }
-
- StepVerifier.create(httpHubAuthClient.requestAccessToken("some-client-id", "some-client-secret"))
- .verifyError(HubAccessTokenNotObtainable.class);
-
- for (int i = 0; i < (MAX_RETRIES + 1); i++) {
- var recordedRequest = mockWebServer.takeRequest();
- assertEquals("/token", recordedRequest.getPath());
- }
- }
-}
diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java
new file mode 100644
index 0000000..4f43adf
--- /dev/null
+++ b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/HubOIDCAuthenticatorIT.java
@@ -0,0 +1,457 @@
+package de.privateaim.node_message_broker.common.hub.auth;
+
+import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
+import de.privateaim.node_message_broker.common.OIDCAuthenticator;
+import de.privateaim.node_message_broker.common.OIDCTokenPair;
+import io.jsonwebtoken.Jwts;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Nested;
+import org.junit.jupiter.api.Test;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2Token;
+import org.springframework.web.reactive.function.client.WebClient;
+import org.testcontainers.shaded.com.fasterxml.jackson.core.JsonProcessingException;
+import org.testcontainers.shaded.com.fasterxml.jackson.databind.ObjectMapper;
+import reactor.test.StepVerifier;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Date;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class HubOIDCAuthenticatorIT {
+
+ private MockWebServer mockWebServer;
+ private OIDCAuthenticator authenticator;
+
+ private static final int MAX_RETRIES = 5;
+ private static final int RETRY_DELAY_MILLIS = 10; // keeping it short for testing purposes!
+
+ private static final String CLIENT_ID = "some-client-id";
+ private static final String CLIENT_SECRET = "some-client-secret";
+
+ @BeforeEach
+ void setUp() {
+ mockWebServer = new MockWebServer();
+ authenticator = HubOIDCAuthenticator.builder()
+ .usingWebClient(WebClient.create(mockWebServer.url("/").toString()))
+ .withRetryConfig(new HttpRetryConfig(MAX_RETRIES, RETRY_DELAY_MILLIS))
+ .withJsonDecoder(new com.fasterxml.jackson.databind.ObjectMapper())
+ .withAuthCredentials(CLIENT_ID, CLIENT_SECRET)
+ .build();
+ }
+
+ @AfterEach
+ void tearDown() throws IOException {
+ mockWebServer.shutdown();
+ }
+
+ @Nested
+ class AuthenticationTests {
+
+ @Test
+ void succeedsEvenIfServerDoesNotSendRefreshToken() throws JsonProcessingException {
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of("access_token", accessToken))
+ ));
+
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.authenticate())
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isEmpty());
+ }
+
+ @Test
+ void failsIfServerDoesNotSendAccessToken() throws JsonProcessingException {
+ var refreshToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of("refresh_token", refreshToken))
+ ));
+
+ StepVerifier.create(authenticator.authenticate())
+ .expectError()
+ .verify();
+ }
+
+ @Test
+ void failsIfAccessTokenIsNotAValidJWT() throws JsonProcessingException {
+ var tokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", "test",
+ "refresh_token", tokenValue))
+ ));
+
+ StepVerifier.create(authenticator.authenticate())
+ .expectError(HubAuthMalformedResponseException.class)
+ .verify();
+ }
+
+ @Test
+ void failsIfRefreshTokenIsNotAValidJWT() throws JsonProcessingException {
+ var tokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", tokenValue,
+ "refresh_token", "test_refresh"))
+ ));
+
+ StepVerifier.create(authenticator.authenticate())
+ .expectError(HubAuthMalformedResponseException.class)
+ .verify();
+ }
+
+ @Test
+ void authenticateSucceedsOnFirstTry() throws JsonProcessingException {
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ var refreshToken = Jwts.builder()
+ .subject("test_refresh")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", accessToken,
+ "refresh_token", refreshToken
+ ))));
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.authenticate())
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent());
+ assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue());
+ }
+
+ @Test
+ void authenticateSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException {
+ if (MAX_RETRIES < 2) {
+ fail("misconfigured test environment");
+ }
+
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ var refreshToken = Jwts.builder()
+ .subject("test_refresh")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", accessToken,
+ "refresh_token", refreshToken
+ ))
+ ));
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.authenticate())
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(2, mockWebServer.getRequestCount());
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent());
+ assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue());
+
+ for (int i = 0; i < 2; i++) {
+ var recordedRequest = mockWebServer.takeRequest();
+ assertEquals("/token", recordedRequest.getPath());
+ }
+ }
+
+ @Test
+ void authenticateFinallyFailsAfterExhaustedRetryCount() throws InterruptedException {
+ for (int i = 0; i < (MAX_RETRIES + 1); i++) {
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
+ }
+
+ StepVerifier.create(authenticator.authenticate())
+ .verifyError(HubAccessTokenNotObtainable.class);
+
+ for (int i = 0; i < (MAX_RETRIES + 1); i++) {
+ var recordedRequest = mockWebServer.takeRequest();
+ assertEquals("/token", recordedRequest.getPath());
+ }
+ }
+
+ // TODO: add tests for checking what happens if issuedAt and expiresAt are missing...
+ }
+
+ @Nested
+ class RefreshTests {
+ private OAuth2Token getSimpleRefreshToken() {
+ var issuedAt = Instant.now();
+ var expiresAt = Instant.now().plus(Duration.ofDays(5));
+
+ var refreshTokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(issuedAt))
+ .expiration(Date.from(expiresAt))
+ .compact();
+
+ return new OAuth2RefreshToken(
+ refreshTokenValue,
+ issuedAt,
+ expiresAt
+ );
+ }
+
+ @Test
+ void succeedsEvenIfServerDoesNotSendRefreshToken() throws JsonProcessingException {
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of("access_token", accessToken))
+ ));
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.authenticate())
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isEmpty());
+ }
+
+ @Test
+ void failsIfServerDoesNotSendAccessToken() throws JsonProcessingException {
+ var refreshToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of("refresh_token", refreshToken))
+ ));
+
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .expectError()
+ .verify();
+ }
+
+ @Test
+ void failsIfAccessTokenIsNotAValidJWT() throws JsonProcessingException {
+ var tokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", "test",
+ "refresh_token", tokenValue))
+ ));
+
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .expectError(HubAuthMalformedResponseException.class)
+ .verify();
+ }
+
+ @Test
+ void failsIfRefreshTokenIsNotAValidJWT() throws JsonProcessingException {
+ var tokenValue = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", tokenValue,
+ "refresh_token", "test_refresh"))
+ ));
+
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .expectError(HubAuthMalformedResponseException.class)
+ .verify();
+ }
+
+ @Test
+ void refreshSucceedsOnFirstTry() throws JsonProcessingException {
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ var refreshToken = Jwts.builder()
+ .subject("test_refresh")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", accessToken,
+ "refresh_token", refreshToken
+ ))));
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent());
+ assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue());
+ }
+
+ @Test
+ void refreshSucceedsWithinRetryRange() throws JsonProcessingException, InterruptedException {
+ if (MAX_RETRIES < 2) {
+ fail("misconfigured test environment");
+ }
+
+ var accessToken = Jwts.builder()
+ .subject("test")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ var refreshToken = Jwts.builder()
+ .subject("test_refresh")
+ .issuedAt(Date.from(Instant.now()))
+ .expiration(Date.from(Instant.now().plus(Duration.ofDays(5))))
+ .compact();
+
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_OK)
+ .setHeader("Content-Type", "application/json")
+ .setBody(
+ new ObjectMapper().writeValueAsString(Map.of(
+ "access_token", accessToken,
+ "refresh_token", refreshToken
+ ))
+ ));
+
+ var oidcTokenPair = new ArrayList();
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .recordWith(ArrayList::new)
+ .expectNextCount(1)
+ .consumeRecordedWith(oidcTokenPair::addAll)
+ .expectComplete()
+ .verify();
+
+ assertEquals(2, mockWebServer.getRequestCount());
+ assertEquals(1, oidcTokenPair.size());
+ assertEquals(accessToken, oidcTokenPair.getFirst().accessToken().getTokenValue());
+ assertTrue(oidcTokenPair.getFirst().refreshToken().isPresent());
+ assertEquals(refreshToken, oidcTokenPair.getFirst().refreshToken().get().getTokenValue());
+
+ for (int i = 0; i < 2; i++) {
+ var recordedRequest = mockWebServer.takeRequest();
+ assertEquals("/token", recordedRequest.getPath());
+ }
+ }
+
+ @Test
+ void refreshFinallyFailsAfterExhaustedRetryCount() throws InterruptedException {
+ for (int i = 0; i < (MAX_RETRIES + 1); i++) {
+ mockWebServer.enqueue(new MockResponse().setResponseCode(HttpStatus.SC_SERVICE_UNAVAILABLE));
+ }
+
+ StepVerifier.create(authenticator.refresh(getSimpleRefreshToken()))
+ .verifyError(HubAccessTokenNotObtainable.class);
+
+ for (int i = 0; i < (MAX_RETRIES + 1); i++) {
+ var recordedRequest = mockWebServer.takeRequest();
+ assertEquals("/token", recordedRequest.getPath());
+ }
+ }
+ }
+}
diff --git a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java b/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java
deleted file mode 100644
index 9f35113..0000000
--- a/src/test/java/de/privateaim/node_message_broker/common/hub/auth/RenewAuthTokenFilterTest.java
+++ /dev/null
@@ -1,106 +0,0 @@
-package de.privateaim.node_message_broker.common.hub.auth;
-
-import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
-import org.junit.jupiter.api.AfterEach;
-import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.junit.jupiter.MockitoExtension;
-import org.springframework.http.HttpMethod;
-import org.springframework.http.HttpStatusCode;
-import org.springframework.web.reactive.function.client.ClientRequest;
-import org.springframework.web.reactive.function.client.ClientResponse;
-import org.springframework.web.reactive.function.client.ExchangeFunction;
-import reactor.core.publisher.Mono;
-import reactor.test.StepVerifier;
-
-import java.net.URI;
-
-import static org.mockito.Mockito.*;
-
-@ExtendWith(MockitoExtension.class)
-public class RenewAuthTokenFilterTest {
-
- private static final String ROBOT_ID = "test-id";
- private static final String ROBOT_SECRET = "test-secret";
-
- @Mock
- private HubAuthClient hubAuthClientMock;
- private RenewAuthTokenFilter renewAuthTokenFilter;
-
- @BeforeEach
- void setUp() {
- renewAuthTokenFilter = new RenewAuthTokenFilter(hubAuthClientMock, ROBOT_ID, ROBOT_SECRET);
- }
-
- @AfterEach
- void tearDown() {
- Mockito.reset(hubAuthClientMock);
- }
-
- @Test
- void requestDoesNotGetReAuthenticatedOnSuccess() {
- var targetResponse = Mockito.mock(ClientResponse.class);
- doReturn(HttpStatusCode.valueOf(HttpStatus.SC_OK)).when(targetResponse).statusCode();
- var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build();
-
- ExchangeFunction targetExchangeFunction = r -> Mono.just(targetResponse);
-
- var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction);
-
- StepVerifier.create(actualResponse)
- .expectNext(targetResponse)
- .expectComplete()
- .verify();
- }
-
- @Test
- void requestGetsReAuthenticatedOnFailure() {
- doReturn(Mono.just("some-access-token")).when(hubAuthClientMock).requestAccessToken(ROBOT_ID, ROBOT_SECRET);
- var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED))
- .body("unauthorized").build();
- var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build();
-
- var reAuthenticatedTargetResponse = Mockito.mock(ClientResponse.class);
- var targetExchangeFunction = Mockito.mock(ExchangeFunction.class);
- doReturn(Mono.just(unauthorizedTargetResponse), Mono.just(reAuthenticatedTargetResponse))
- .when(targetExchangeFunction)
- .exchange(Mockito.any(ClientRequest.class));
-
- var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction);
-
- StepVerifier.create(actualResponse)
- .expectNext(reAuthenticatedTargetResponse)
- .expectComplete()
- .verify();
-
- verify(hubAuthClientMock, times(1))
- .requestAccessToken(ROBOT_ID, ROBOT_SECRET);
- }
-
- @Test
- void requestUltimatelyFailsOnReAuthenticatedRequestFailure() {
- doReturn(Mono.error(new Exception("some-error"))).when(hubAuthClientMock)
- .requestAccessToken(ROBOT_ID, ROBOT_SECRET);
-
- var unauthorizedTargetResponse = ClientResponse.create(HttpStatusCode.valueOf(HttpStatus.SC_UNAUTHORIZED))
- .body("unauthorized").build();
- var request = ClientRequest.create(HttpMethod.GET, URI.create("/some-resource")).build();
-
- var targetExchangeFunction = Mockito.mock(ExchangeFunction.class);
- doReturn(Mono.just(unauthorizedTargetResponse))
- .when(targetExchangeFunction)
- .exchange(Mockito.any(ClientRequest.class));
-
- var actualResponse = renewAuthTokenFilter.filter(request, targetExchangeFunction);
-
- StepVerifier.create(actualResponse)
- .expectError(Exception.class)
- .verify();
-
- verify(hubAuthClientMock, times(1))
- .requestAccessToken(ROBOT_ID, ROBOT_SECRET);
- }
-}
diff --git a/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java b/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java
index 8d72c37..0290ad2 100644
--- a/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java
+++ b/src/test/java/de/privateaim/node_message_broker/discovery/DiscoveryServiceIT.java
@@ -1,8 +1,8 @@
package de.privateaim.node_message_broker.discovery;
import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
import de.privateaim.node_message_broker.common.hub.HttpHubClient;
-import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig;
import de.privateaim.node_message_broker.common.hub.api.AnalysisNode;
import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer;
import de.privateaim.node_message_broker.common.hub.api.Node;
@@ -39,10 +39,7 @@ public final class DiscoveryServiceIT {
void setUp() {
mockWebServer = new MockWebServer();
var noAuthWebClient = WebClient.create(mockWebServer.url("/").toString());
- var hubClientCfg = new HttpHubClientConfig.Builder()
- .withMaxRetries(0)
- .withRetryDelayMs(0)
- .build();
+ var hubClientCfg = new HttpRetryConfig(0, 0);
var hubClient = Mockito.spy(new HttpHubClient(noAuthWebClient, hubClientCfg));
discoveryService = new DiscoveryService(hubClient, SELF_ROBOT_ID);
}
diff --git a/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java b/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java
index a5a08c9..caedb6a 100644
--- a/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java
+++ b/src/test/java/de/privateaim/node_message_broker/message/MessageServiceIT.java
@@ -4,8 +4,8 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.JsonNodeFactory;
import com.github.dockerjava.zerodep.shaded.org.apache.hc.core5.http.HttpStatus;
+import de.privateaim.node_message_broker.common.HttpRetryConfig;
import de.privateaim.node_message_broker.common.hub.HttpHubClient;
-import de.privateaim.node_message_broker.common.hub.HttpHubClientConfig;
import de.privateaim.node_message_broker.common.hub.api.AnalysisNode;
import de.privateaim.node_message_broker.common.hub.api.HubResponseContainer;
import de.privateaim.node_message_broker.common.hub.api.Node;
@@ -57,10 +57,7 @@ void setUp() {
mockSocket = Mockito.mock(Socket.class);
spyMessageEmitter = Mockito.spy(new HubMessageEmitter(mockSocket));
var webClient = WebClient.create(mockWebServer.url("/").toString());
- var httpHubClient = new HttpHubClient(webClient, new HttpHubClientConfig.Builder()
- .withMaxRetries(0)
- .withRetryDelayMs(0)
- .build());
+ var httpHubClient = new HttpHubClient(webClient, new HttpRetryConfig(0, 0));
messageService = new MessageService(spyMessageEmitter, httpHubClient, SELF_ROBOT_ID);
emitMessageCaptor = ArgumentCaptor.forClass(EmitMessage.class);