diff --git a/docs/server/docs.go b/docs/server/docs.go index 1a5734c59e..5fa4849019 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -230,8 +230,12 @@ const docTemplate = `{ "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 4c347eed6b..82a12706f7 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -223,8 +223,12 @@ "description": "Cached OAuth token reference for persistence across restarts.\nThe refresh token is stored securely in the secret manager, and this field\ncontains the reference to retrieve it (e.g., \"OAUTH_REFRESH_TOKEN_workload\").\nThis enables session restoration without requiring a new browser-based login.", "type": "string" }, + "cached_reg_client_uri": { + "description": "CachedRegClientURI is the registration_client_uri from the DCR response.\nThis is the endpoint used for RFC 7592 client read/update/delete operations.\nStored as plain text since it is not sensitive.", + "type": "string" + }, "cached_reg_token_ref": { - "description": "RegistrationAccessToken is used to update/delete the client registration.\nStored as a secret reference since it's sensitive.", + "description": "CachedRegTokenRef is a secret manager reference to the registration_access_token\nreturned in the DCR response. Used for RFC 7592 client update operations.\nStored as a secret reference since it's sensitive.", "type": "string" }, "cached_secret_expiry": { diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 741cedbe4a..738fff642e 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -228,9 +228,16 @@ components: contains the reference to retrieve it (e.g., "OAUTH_REFRESH_TOKEN_workload"). This enables session restoration without requiring a new browser-based login. type: string + cached_reg_client_uri: + description: |- + CachedRegClientURI is the registration_client_uri from the DCR response. + This is the endpoint used for RFC 7592 client read/update/delete operations. + Stored as plain text since it is not sensitive. + type: string cached_reg_token_ref: description: |- - RegistrationAccessToken is used to update/delete the client registration. + CachedRegTokenRef is a secret manager reference to the registration_access_token + returned in the DCR response. Used for RFC 7592 client update operations. Stored as a secret reference since it's sensitive. type: string cached_secret_expiry: diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 6a69fc6a15..8582a89056 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -509,6 +509,12 @@ type OAuthFlowConfig struct { SkipBrowser bool Resource string // RFC 8707 resource indicator (optional) OAuthParams map[string]string + + // DCR renewal metadata — populated by handleDynamicRegistration and threaded + // into OAuthFlowResult so callers can persist the data for RFC 7592 operations. + SecretExpiry time.Time // zero means the secret never expires + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string } // OAuthFlowResult contains the result of an OAuth flow @@ -524,6 +530,14 @@ type OAuthFlowResult struct { // DCR client credentials for persistence (obtained during Dynamic Client Registration) ClientID string ClientSecret string //nolint:gosec // G117: field legitimately holds sensitive data + + // DCR renewal metadata (RFC 7591 §3.2.1 / RFC 7592). + // SecretExpiry is zero when the provider did not issue an expiring secret. + // RegistrationAccessToken and RegistrationClientURI are empty when the + // provider does not support RFC 7592 management operations. + SecretExpiry time.Time + RegistrationAccessToken string //nolint:gosec // G117: field legitimately holds sensitive data + RegistrationClientURI string } func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { @@ -581,7 +595,10 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return newOAuthFlow(ctx, oauthConfig, config) } -// handleDynamicRegistration handles the dynamic client registration process +// handleDynamicRegistration handles the dynamic client registration process. +// It populates config with the client credentials AND the DCR renewal metadata +// (SecretExpiry, RegistrationAccessToken, RegistrationClientURI) so that +// callers can persist the full RFC 7592 context for later secret renewal. func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuthFlowConfig) error { discoveredDoc, err := getDiscoveryDocument(ctx, issuer, config) if err != nil { @@ -602,6 +619,18 @@ func handleDynamicRegistration(ctx context.Context, issuer string, config *OAuth config.TokenURL = discoveredDoc.TokenEndpoint } + // Store DCR renewal metadata for RFC 7592 operations. + // client_secret_expires_at == 0 means the secret never expires (RFC 7591 §3.2.1). + if registrationResponse.ClientSecretExpiresAt > 0 { + config.SecretExpiry = time.Unix(registrationResponse.ClientSecretExpiresAt, 0) + } + config.RegistrationAccessToken = registrationResponse.RegistrationAccessToken + config.RegistrationClientURI = registrationResponse.RegistrationClientURI + + if registrationResponse.RegistrationAccessToken != "" { + slog.Debug("DCR response includes registration access token for RFC 7592 operations") + } + return nil } @@ -707,6 +736,10 @@ func newOAuthFlow(ctx context.Context, oauthConfig *oauth.Config, config *OAuthF Expiry: tokenResult.Expiry, ClientID: oauthConfig.ClientID, ClientSecret: oauthConfig.ClientSecret, + // DCR renewal metadata — populated only when dynamic registration was performed. + SecretExpiry: config.SecretExpiry, + RegistrationAccessToken: config.RegistrationAccessToken, + RegistrationClientURI: config.RegistrationClientURI, }, nil } diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 049e5ab6b2..ff87abfa1f 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -63,9 +63,14 @@ type Config struct { // ClientSecretExpiresAt indicates when the client secret expires (if provided by the DCR server). // A zero value means the secret does not expire. CachedSecretExpiry time.Time `json:"cached_secret_expiry,omitempty" yaml:"cached_secret_expiry,omitempty"` - // RegistrationAccessToken is used to update/delete the client registration. + // CachedRegTokenRef is a secret manager reference to the registration_access_token + // returned in the DCR response. Used for RFC 7592 client update operations. // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` + // CachedRegClientURI is the registration_client_uri from the DCR response. + // This is the endpoint used for RFC 7592 client read/update/delete operations. + // Stored as plain text since it is not sensitive. + CachedRegClientURI string `json:"cached_reg_client_uri,omitempty" yaml:"cached_reg_client_uri,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -165,6 +170,7 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedClientSecretRef = "" c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" + c.CachedRegClientURI = "" } // DefaultResourceIndicator derives the resource indicator (RFC 8707) from the remote server URL. diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index fd82e6c2b9..58ff2495cf 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "log/slog" + "time" "golang.org/x/oauth2" @@ -187,7 +188,13 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // Persist DCR client credentials if available (for servers that use Dynamic Client Registration) // Only persist if client_id exists - client_secret may be empty for PKCE flows if h.clientCredentialsPersister != nil && result.ClientID != "" { - if err := h.clientCredentialsPersister(result.ClientID, result.ClientSecret); err != nil { + if err := h.clientCredentialsPersister( + result.ClientID, + result.ClientSecret, + result.SecretExpiry, + result.RegistrationAccessToken, + result.RegistrationClientURI, + ); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { slog.Debug("Successfully persisted DCR client credentials for future restarts") @@ -205,6 +212,8 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // resolveClientCredentials returns the client ID and secret to use, preferring // cached DCR credentials over statically configured ones. +// If the cached client secret is expiring soon, it attempts renewal via RFC 7592 +// before returning the credentials. func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clientSecret string) { // First try to use statically configured credentials clientID = h.config.ClientID @@ -216,6 +225,18 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien clientID = h.config.CachedClientID slog.Debug("Using cached DCR client credentials", "client_id", clientID) + // Proactively renew the client secret if it is expiring soon (RFC 7592) + if h.isSecretExpiredOrExpiringSoon() { + slog.Info("Cached client secret is expiring soon, attempting renewal", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Failed to proactively renew client secret; continuing with existing secret", + "error", renewErr) + } else { + slog.Debug("Successfully renewed client secret ahead of expiry") + } + } + // Client secret is stored securely and may be empty for PKCE flows if h.config.CachedClientSecretRef != "" && h.secretProvider != nil { cachedClientSecret, err := h.secretProvider.GetSecret(ctx, h.config.CachedClientSecretRef) @@ -242,6 +263,27 @@ func (h *Handler) tryRestoreFromCachedTokens( return nil, fmt.Errorf("secret provider not configured, cannot restore cached tokens") } + // Check if the cached client secret is expired before attempting token refresh. + // If it has fully expired and renewal also fails we must force a fresh OAuth flow. + if h.isSecretExpiredOrExpiringSoon() { + slog.Info("Cached client secret is expiring or expired; attempting renewal before token restore", + "expiry", h.config.CachedSecretExpiry) + if renewErr := h.renewClientSecret(ctx); renewErr != nil { + slog.Warn("Client secret renewal failed", "error", renewErr) + // Hard-fail only when the secret is already past its expiry. + // If we are still in the buffer window the existing secret may work. + if !h.config.CachedSecretExpiry.IsZero() && time.Now().After(h.config.CachedSecretExpiry) { + return nil, fmt.Errorf( + "client secret expired at %v and renewal failed: %w", + h.config.CachedSecretExpiry, renewErr) + } + // Still within buffer — log and continue with the existing (still-valid) secret + slog.Warn("Proceeding with expiring client secret after failed renewal attempt") + } else { + slog.Debug("Successfully renewed client secret before token restore") + } + } + refreshToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRefreshTokenRef) if err != nil { return nil, fmt.Errorf("failed to retrieve cached refresh token: %w", err) diff --git a/pkg/auth/remote/persisting_token_source.go b/pkg/auth/remote/persisting_token_source.go index f6b7b451e8..32af4e3a5e 100644 --- a/pkg/auth/remote/persisting_token_source.go +++ b/pkg/auth/remote/persisting_token_source.go @@ -18,8 +18,22 @@ import ( type TokenPersister func(refreshToken string, expiry time.Time) error // ClientCredentialsPersister is called when DCR client credentials need to be persisted. -// This is used to store client_id and client_secret obtained during Dynamic Client Registration. -type ClientCredentialsPersister func(clientID, clientSecret string) error +// This is used to store client_id, client_secret, and renewal metadata obtained during +// Dynamic Client Registration (RFC 7591) and needed for secret renewal (RFC 7592). +// +// Parameters: +// - clientID: the registered client ID (public, stored as plain text) +// - clientSecret: the registered client secret (sensitive, stored via secret manager) +// - secretExpiry: when the client secret expires; zero value means it never expires +// - registrationAccessToken: bearer token for RFC 7592 management operations (sensitive) +// - registrationClientURI: endpoint for RFC 7592 client update/read operations (plain text) +type ClientCredentialsPersister func( + clientID string, + clientSecret string, + secretExpiry time.Time, + registrationAccessToken string, + registrationClientURI string, +) error // PersistingTokenSource wraps an oauth2.TokenSource and persists tokens // whenever they are refreshed. This enables session restoration across diff --git a/pkg/auth/remote/secret_renewal.go b/pkg/auth/remote/secret_renewal.go new file mode 100644 index 0000000000..01c1a58929 --- /dev/null +++ b/pkg/auth/remote/secret_renewal.go @@ -0,0 +1,230 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + "github.com/stacklok/toolhive/pkg/networking" +) + +// secretExpiryBuffer is the lead time before expiry at which we proactively +// renew the client secret (RFC 7592). Renewal is attempted when the secret +// expires within this window, not only after expiry. +const secretExpiryBuffer = 24 * time.Hour + +// clientUpdateRequest is the body sent in a RFC 7592 §2.2 PUT request. +// Per the spec, all client metadata fields that were provided during +// registration must be included in the update request body. +type clientUpdateRequest struct { + ClientID string `json:"client_id"` + ClientName string `json:"client_name,omitempty"` + RedirectURIs []string `json:"redirect_uris,omitempty"` + GrantTypes []string `json:"grant_types,omitempty"` + ResponseTypes []string `json:"response_types,omitempty"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` +} + +// clientUpdateResponse is the body returned by a RFC 7592 §2.1 response. +// The provider may rotate the registration_access_token; if present we must +// replace the stored one. +type clientUpdateResponse struct { + // Required fields mirrored from registration response + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret,omitempty"` //nolint:gosec // G117: field holds sensitive data + + // Expiry fields + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + + // Management fields — registration_access_token may be rotated + RegistrationAccessToken string `json:"registration_access_token,omitempty"` //nolint:gosec + RegistrationClientURI string `json:"registration_client_uri,omitempty"` +} + +// isSecretExpiredOrExpiringSoon returns true when the cached client secret is +// already expired or will expire within secretExpiryBuffer. +// A zero CachedSecretExpiry means the secret never expires, so this returns false. +func (h *Handler) isSecretExpiredOrExpiringSoon() bool { + expiry := h.config.CachedSecretExpiry + if expiry.IsZero() { + return false // Non-expiring secret + } + return time.Now().After(expiry.Add(-secretExpiryBuffer)) +} + +// renewClientSecret attempts to renew the client secret using RFC 7592 §2.2. +// It retrieves the stored registration_access_token and sends a PUT request +// to the registration_client_uri with the current client metadata. +// +// On success the handler's config is updated with the new secret, expiry, and +// (if rotated) the new registration_access_token. +// +// Callers should log a warning and continue if renewal fails — the existing +// secret may still be valid for some time, or the provider may not support renewal. +func (h *Handler) renewClientSecret(ctx context.Context) error { + if err := h.validateRenewalPrerequisites(); err != nil { + return err + } + + // Retrieve the registration access token from the secret manager + regAccessToken, err := h.secretProvider.GetSecret(ctx, h.config.CachedRegTokenRef) + if err != nil { + return fmt.Errorf("failed to retrieve registration access token: %w", err) + } + + slog.Debug("Attempting RFC 7592 client secret renewal", + "registration_client_uri", h.config.CachedRegClientURI) + + // Validate the registration_client_uri before using it + if err := validateRegistrationClientURI(h.config.CachedRegClientURI); err != nil { + return fmt.Errorf("invalid registration_client_uri: %w", err) + } + + // Build the update request body with the current client metadata. + // Per RFC 7592 §2.2, the request MUST include all client metadata fields + // that were provided during the initial registration. + updateReq := clientUpdateRequest{ + ClientID: h.config.CachedClientID, + ClientName: "ToolHive MCP Client", + RedirectURIs: []string{fmt.Sprintf("http://localhost:%d/callback", h.config.CallbackPort)}, + GrantTypes: []string{"authorization_code", "refresh_token"}, + ResponseTypes: []string{"code"}, + TokenEndpointAuthMethod: "none", + } + + reqBody, err := json.Marshal(updateReq) + if err != nil { + return fmt.Errorf("failed to marshal client update request: %w", err) + } + + // Create PUT request per RFC 7592 §2.2 + req, err := http.NewRequestWithContext( + ctx, + http.MethodPut, + h.config.CachedRegClientURI, + strings.NewReader(string(reqBody)), + ) + if err != nil { + return fmt.Errorf("failed to create client update request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+regAccessToken) //nolint:gosec // G117 + + // Execute the request + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + }, + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("client update request failed: %w", err) + } + defer func() { + if closeErr := resp.Body.Close(); closeErr != nil { + slog.Debug("Failed to close renewal response body", "error", closeErr) + } + }() + + if resp.StatusCode != http.StatusOK { + errorBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) + return fmt.Errorf("client update request returned HTTP %d: %s", resp.StatusCode, string(errorBody)) + } + + // Parse the renewal response + const maxResponseSize = 1024 * 1024 // 1 MB + var updateResp clientUpdateResponse + if err := json.NewDecoder(io.LimitReader(resp.Body, maxResponseSize)).Decode(&updateResp); err != nil { + return fmt.Errorf("failed to decode client update response: %w", err) + } + + if updateResp.ClientID == "" { + return fmt.Errorf("client update response missing client_id") + } + if updateResp.ClientSecret == "" { + return fmt.Errorf("client update response missing client_secret") + } + + return h.persistRenewedSecret(updateResp) +} + +func (h *Handler) validateRenewalPrerequisites() error { + if h.config.CachedRegClientURI == "" { + return fmt.Errorf("registration_client_uri missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.config.CachedRegTokenRef == "" { + return fmt.Errorf("registration_access_token missing; cannot renew secret (RFC 7592 unsupported)") + } + if h.secretProvider == nil { + return fmt.Errorf("secret provider not configured; cannot retrieve registration access token") + } + return nil +} + +func (h *Handler) persistRenewedSecret(updateResp clientUpdateResponse) error { + if h.clientCredentialsPersister == nil { + return fmt.Errorf("client credentials persister not configured; cannot save renewed secret") + } + + var newExpiry time.Time + if updateResp.ClientSecretExpiresAt > 0 { + newExpiry = time.Unix(updateResp.ClientSecretExpiresAt, 0) + } + + // Use the rotated registration_access_token if provided; fall back to existing. + newRegToken := updateResp.RegistrationAccessToken + newRegURI := updateResp.RegistrationClientURI + if newRegURI == "" { + newRegURI = h.config.CachedRegClientURI + } + + if err := h.clientCredentialsPersister( + updateResp.ClientID, + updateResp.ClientSecret, + newExpiry, + newRegToken, + newRegURI, + ); err != nil { + return fmt.Errorf("failed to persist renewed client secret: %w", err) + } + + slog.Info("Successfully renewed client secret via RFC 7592", + "client_id", updateResp.ClientID, + "new_expiry_zero", newExpiry.IsZero(), + "reg_token_rotated", newRegToken != "") + + return nil +} + +// validateRegistrationClientURI validates that the registration_client_uri is +// a valid HTTPS URL (or localhost for development). +func validateRegistrationClientURI(registrationClientURI string) error { + if registrationClientURI == "" { + return fmt.Errorf("registration_client_uri is empty") + } + + parsedURL, err := url.Parse(registrationClientURI) + if err != nil { + return fmt.Errorf("invalid registration_client_uri URL: %w", err) + } + + if parsedURL.Scheme != "https" && !networking.IsLocalhost(parsedURL.Host) { + return fmt.Errorf("registration_client_uri must use HTTPS: %s", registrationClientURI) + } + + return nil +} diff --git a/pkg/auth/remote/secret_renewal_test.go b/pkg/auth/remote/secret_renewal_test.go new file mode 100644 index 0000000000..04ae124764 --- /dev/null +++ b/pkg/auth/remote/secret_renewal_test.go @@ -0,0 +1,383 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package remote + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/secrets" +) + +// mockSecretProvider is a simple in-memory secret store for tests. +// It implements the full secrets.Provider interface. +type mockSecretProvider struct { + secrets map[string]string +} + +func newMockSecretProvider(initial map[string]string) *mockSecretProvider { + if initial == nil { + initial = make(map[string]string) + } + return &mockSecretProvider{secrets: initial} +} + +func (m *mockSecretProvider) GetSecret(_ context.Context, name string) (string, error) { + v, ok := m.secrets[name] + if !ok { + return "", fmt.Errorf("secret %q not found", name) + } + return v, nil +} + +func (m *mockSecretProvider) SetSecret(_ context.Context, name, value string) error { + m.secrets[name] = value + return nil +} + +func (m *mockSecretProvider) DeleteSecret(_ context.Context, name string) error { + delete(m.secrets, name) + return nil +} + +func (m *mockSecretProvider) ListSecrets(_ context.Context) ([]secrets.SecretDescription, error) { + result := make([]secrets.SecretDescription, 0, len(m.secrets)) + for k := range m.secrets { + result = append(result, secrets.SecretDescription{Key: k}) + } + return result, nil +} + +func (*mockSecretProvider) Cleanup() error { return nil } + +func (*mockSecretProvider) Capabilities() secrets.ProviderCapabilities { + return secrets.ProviderCapabilities{ + CanRead: true, + CanWrite: true, + CanDelete: true, + CanList: true, + } +} + +// TestIsSecretExpiredOrExpiringSoon tests the expiry helper on various time scenarios. +func TestIsSecretExpiredOrExpiringSoon(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + expiry time.Time + wantExpired bool + }{ + { + name: "zero expiry means never expires", + expiry: time.Time{}, + wantExpired: false, + }, + { + name: "expiry far in the future — not expiring", + expiry: time.Now().Add(48 * time.Hour), + wantExpired: false, + }, + { + name: "expiry within 24h buffer — expiring soon", + expiry: time.Now().Add(12 * time.Hour), + wantExpired: true, + }, + { + name: "expiry in the past — already expired", + expiry: time.Now().Add(-1 * time.Hour), + wantExpired: true, + }, + { + name: "expiry exactly at buffer boundary — expiring soon", + expiry: time.Now().Add(secretExpiryBuffer - time.Minute), + wantExpired: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedSecretExpiry: tt.expiry, + }, + } + assert.Equal(t, tt.wantExpired, h.isSecretExpiredOrExpiringSoon()) + }) + } +} + +// TestValidateRegistrationClientURI tests URI validation. +func TestValidateRegistrationClientURI(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + uri string + wantErr bool + }{ + { + name: "empty URI", + uri: "", + wantErr: true, + }, + { + name: "valid HTTPS URI", + uri: "https://example.com/oauth/register/client-id", + wantErr: false, + }, + { + name: "HTTP URI for non-localhost is rejected", + uri: "http://example.com/oauth/register/client-id", + wantErr: true, + }, + { + name: "localhost HTTP is allowed (development)", + uri: "http://localhost:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "127.0.0.1 HTTP is allowed (development)", + uri: "http://127.0.0.1:8080/oauth/register/client-id", + wantErr: false, + }, + { + name: "invalid URL", + uri: "://bad-url", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateRegistrationClientURI(tt.uri) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// TestRenewClientSecret_MissingConfig tests early-exit conditions. +func TestRenewClientSecret_MissingConfig(t *testing.T) { + t.Parallel() + + t.Run("missing registration_client_uri", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "", + CachedRegTokenRef: "some-ref", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_client_uri missing") + }) + + t.Run("missing registration_token_ref", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "", + }, + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "registration_access_token missing") + }) + + t.Run("missing secret provider", func(t *testing.T) { + t.Parallel() + + h := &Handler{ + config: &Config{ + CachedRegClientURI: "https://example.com/register/client-id", + CachedRegTokenRef: "some-ref", + }, + secretProvider: nil, // no provider + } + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "secret provider not configured") + }) +} + +// TestRenewClientSecret_Success tests the happy path with a mock RFC 7592 server. +func TestRenewClientSecret_Success(t *testing.T) { + t.Parallel() + + newSecret := "new-client-secret-xyz" + newExpiry := time.Now().Add(24 * time.Hour * 30).Unix() + newRegToken := "new-registration-access-token" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // RFC 7592 §2.2: must be PUT with Bearer auth + assert.Equal(t, http.MethodPut, r.Method) + assert.Contains(t, r.Header.Get("Authorization"), "Bearer reg-access-token") + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + // Return the updated registration response + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": newSecret, + "client_secret_expires_at": newExpiry, + "registration_access_token": newRegToken, + "registration_client_uri": "http://" + r.Host + r.URL.Path, + }) + })) + defer server.Close() + + // Set up persister capture + var persistedClientID, persistedSecret, persistedRegToken, persistedRegURI string + var persistedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-secret-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-secret-ref": "reg-access-token", + }), + clientCredentialsPersister: func( + clientID, secret string, + expiry time.Time, + regToken, regURI string, + ) error { + persistedClientID = clientID + persistedSecret = secret + persistedExpiry = expiry + persistedRegToken = regToken + persistedRegURI = regURI + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + + assert.Equal(t, "test-client-id", persistedClientID) + assert.Equal(t, newSecret, persistedSecret) + assert.Equal(t, newRegToken, persistedRegToken) + assert.False(t, persistedExpiry.IsZero(), "expiry should be set") + assert.NotEmpty(t, persistedRegURI) +} + +// TestRenewClientSecret_ServerError tests error propagation when the server returns non-200. +func TestRenewClientSecret_ServerError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "bad-token", + }), + clientCredentialsPersister: func(_, _ string, _ time.Time, _, _ string) error { + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "401") +} + +// TestRenewClientSecret_NoPersister tests failure when persister is not set. +func TestRenewClientSecret_NoPersister(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + }) + })) + defer server.Close() + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: nil, // no persister + } + + err := h.renewClientSecret(context.Background()) + require.Error(t, err) + assert.Contains(t, err.Error(), "client credentials persister not configured") +} + +// TestRenewClientSecret_ZeroExpiryInResponse tests that a zero client_secret_expires_at +// is correctly interpreted as a non-expiring secret. +func TestRenewClientSecret_ZeroExpiryInResponse(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "client_id": "test-client-id", + "client_secret": "new-secret", + "client_secret_expires_at": 0, // never expires + }) + })) + defer server.Close() + + var capturedExpiry time.Time + + h := &Handler{ + config: &Config{ + CachedClientID: "test-client-id", + CachedRegClientURI: server.URL + "/register/test-client-id", + CachedRegTokenRef: "reg-token-ref", + }, + secretProvider: newMockSecretProvider(map[string]string{ + "reg-token-ref": "some-token", + }), + clientCredentialsPersister: func(_, _ string, expiry time.Time, _, _ string) error { + capturedExpiry = expiry + return nil + }, + } + + err := h.renewClientSecret(context.Background()) + require.NoError(t, err) + assert.True(t, capturedExpiry.IsZero(), "zero client_secret_expires_at must produce zero time.Time") +} diff --git a/pkg/runner/config_test.go b/pkg/runner/config_test.go index 6beadef9bd..3578949ebc 100644 --- a/pkg/runner/config_test.go +++ b/pkg/runner/config_test.go @@ -98,6 +98,22 @@ func TestRunConfig_WithTransport(t *testing.T) { // Note: This test uses actual port finding logic, so it may fail if ports are in use func TestRunConfig_WithPorts(t *testing.T) { t.Parallel() + // Find available ports dynamically to avoid flaky failures + port1 := networking.FindAvailable() + require.NotZero(t, port1, "should find an available proxy port for SSE") + targetPort1 := networking.FindAvailable() + require.NotZero(t, targetPort1, "should find an available target port for SSE") + + port2 := networking.FindAvailable() + require.NotZero(t, port2, "should find an available proxy port for HTTP") + targetPort2 := networking.FindAvailable() + require.NotZero(t, targetPort2, "should find an available target port for HTTP") + + port3 := networking.FindAvailable() + require.NotZero(t, port3, "should find an available proxy port for Stdio") + targetPort3 := networking.FindAvailable() + require.NotZero(t, targetPort3, "should find an available target port for Stdio") + testCases := []struct { name string config *RunConfig @@ -108,8 +124,8 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "SSE transport with specific ports", config: &RunConfig{Transport: types.TransportTypeSSE}, - port: 8001, - targetPort: 9001, + port: port1, + targetPort: targetPort1, expectError: false, }, { @@ -122,15 +138,15 @@ func TestRunConfig_WithPorts(t *testing.T) { { name: "Streamable HTTP transport with specific ports", config: &RunConfig{Transport: types.TransportTypeStreamableHTTP}, - port: 8002, - targetPort: 9002, + port: port2, + targetPort: targetPort2, expectError: false, }, { name: "Stdio transport with specific port", config: &RunConfig{Transport: types.TransportTypeStdio}, - port: 8003, - targetPort: 9003, // This should be ignored for stdio + port: port3, + targetPort: targetPort3, // This should be ignored for stdio expectError: false, }, } diff --git a/pkg/runner/runner.go b/pkg/runner/runner.go index af235acfc1..fb9ef69f46 100644 --- a/pkg/runner/runner.go +++ b/pkg/runner/runner.go @@ -639,63 +639,16 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo // Set up token persister to save tokens across restarts if secretManager != nil { authHandler.SetTokenPersister(func(refreshToken string, expiry time.Time) error { - // Generate a unique secret name for this workload's refresh token - secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_REFRESH_TOKEN_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate secret name: %w", err) - } - - // Store the refresh token in the secret manager - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { - return fmt.Errorf("failed to store refresh token: %w", err) - } - - // Store the secret reference (not the actual token) in the config - r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName - r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry - - // Save the updated config to persist the reference - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with token reference: %w", err) - } - - slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) - return nil + return r.persistRefreshToken(ctx, secretManager, refreshToken, expiry) }) // Set up client credentials persister for DCR (Dynamic Client Registration) - authHandler.SetClientCredentialsPersister(func(clientID, clientSecret string) error { - // Store client ID directly (it's public information) - r.Config.RemoteAuthConfig.CachedClientID = clientID - - // Only store client secret if it's non-empty (PKCE flows may not have one) - if clientSecret != "" { - clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( - r.Config.Name, - "OAUTH_CLIENT_SECRET_", - secretManager, - ) - if err != nil { - return fmt.Errorf("failed to generate client secret secret name: %w", err) - } - - if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { - return fmt.Errorf("failed to store client secret: %w", err) - } - r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName - } - - // Save the updated config to persist the credentials - if err := r.Config.SaveState(ctx); err != nil { - return fmt.Errorf("failed to save config with client credentials: %w", err) - } - - slog.Debug("Stored DCR client credentials", "client_id", clientID) - return nil + authHandler.SetClientCredentialsPersister(func( + clientID, clientSecret string, + secretExpiry time.Time, + regAccessToken, regClientURI string, + ) error { + return r.persistClientCredentials(ctx, secretManager, clientID, clientSecret, secretExpiry, regAccessToken, regClientURI) }) } @@ -708,6 +661,89 @@ func (r *Runner) handleRemoteAuthentication(ctx context.Context) (oauth2.TokenSo return tokenSource, nil } +func (r *Runner) persistRefreshToken( + ctx context.Context, + secretManager secrets.Provider, + refreshToken string, + expiry time.Time, +) error { + secretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_REFRESH_TOKEN_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, secretName, refreshToken, secretManager); err != nil { + return fmt.Errorf("failed to store refresh token: %w", err) + } + + r.Config.RemoteAuthConfig.CachedRefreshTokenRef = secretName + r.Config.RemoteAuthConfig.CachedTokenExpiry = expiry + + if err := r.Config.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with token reference: %w", err) + } + + slog.Debug("Stored OAuth refresh token in secret manager", "secret_name", secretName) + return nil +} + +func (r *Runner) persistClientCredentials( + ctx context.Context, + secretManager secrets.Provider, + clientID, clientSecret string, + secretExpiry time.Time, + regAccessToken, regClientURI string, +) error { + r.Config.RemoteAuthConfig.CachedClientID = clientID + + if clientSecret != "" { + clientSecretSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix( + r.Config.Name, + "OAUTH_CLIENT_SECRET_", + secretManager, + ) + if err != nil { + return fmt.Errorf("failed to generate client secret secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, clientSecretSecretName, clientSecret, secretManager); err != nil { + return fmt.Errorf("failed to store client secret: %w", err) + } + r.Config.RemoteAuthConfig.CachedClientSecretRef = clientSecretSecretName + } + + r.Config.RemoteAuthConfig.CachedSecretExpiry = secretExpiry + + if regAccessToken != "" { + regTokenSecretName, err := authsecrets.GenerateUniqueSecretNameWithPrefix(r.Config.Name, "OAUTH_REG_TOKEN_", secretManager) + if err != nil { + return fmt.Errorf("failed to generate registration token secret name: %w", err) + } + + if err := authsecrets.StoreSecretInManagerWithProvider(ctx, regTokenSecretName, regAccessToken, secretManager); err != nil { + return fmt.Errorf("failed to store registration access token: %w", err) + } + r.Config.RemoteAuthConfig.CachedRegTokenRef = regTokenSecretName + slog.Debug("Stored DCR registration access token for RFC 7592 operations") + } + + r.Config.RemoteAuthConfig.CachedRegClientURI = regClientURI + + if err := r.Config.SaveState(ctx); err != nil { + return fmt.Errorf("failed to save config with client credentials: %w", err) + } + + slog.Debug("Stored DCR client credentials", "client_id", clientID, + "has_expiry", !secretExpiry.IsZero(), + "has_reg_token", regAccessToken != "", + "has_reg_uri", regClientURI != "") + return nil +} + // Cleanup performs cleanup operations for the runner, including shutting down all middleware. func (r *Runner) Cleanup(ctx context.Context) error { // For simplicity, return the last error we encounter during cleanup.