diff --git a/pkg/vmcp/client/auth_retry.go b/pkg/vmcp/client/auth_retry.go new file mode 100644 index 0000000000..4c26bca098 --- /dev/null +++ b/pkg/vmcp/client/auth_retry.go @@ -0,0 +1,276 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "context" + "errors" + "fmt" + "log/slog" + "sync" + "time" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/singleflight" + + "github.com/stacklok/toolhive/pkg/vmcp" + vmcpauth "github.com/stacklok/toolhive/pkg/vmcp/auth" +) + +const ( + // authRetryInstrumentationName is the OpenTelemetry instrumentation scope for auth retries. + authRetryInstrumentationName = "github.com/stacklok/toolhive/pkg/vmcp/client" + + // maxAuthRetries is the maximum number of retry attempts after an auth failure. + maxAuthRetries = 3 + + // authCircuitBreakerThreshold is the number of consecutive auth failures before + // the circuit breaker opens and disables further retries for a backend. + authCircuitBreakerThreshold = 5 + + // initialRetryBackoff is the base duration for exponential backoff between retries. + // Attempt 1: 100ms, Attempt 2: 200ms, Attempt 3: 400ms. + initialRetryBackoff = 100 * time.Millisecond +) + +// authCircuitBreaker tracks consecutive auth failures per backend and opens the circuit +// after too many failures to prevent excessive latency from repeated auth retries. +type authCircuitBreaker struct { + mu sync.Mutex + consecutiveFails int + open bool +} + +// canRetry returns true if auth retries are still allowed (circuit is closed). +func (cb *authCircuitBreaker) canRetry() bool { + cb.mu.Lock() + defer cb.mu.Unlock() + return !cb.open +} + +// recordSuccess resets the consecutive failure counter and closes the circuit. +func (cb *authCircuitBreaker) recordSuccess() { + cb.mu.Lock() + defer cb.mu.Unlock() + cb.consecutiveFails = 0 + cb.open = false +} + +// recordFailure increments the failure counter and opens the circuit if the threshold is exceeded. +func (cb *authCircuitBreaker) recordFailure(threshold int, backendID string) { + cb.mu.Lock() + defer cb.mu.Unlock() + cb.consecutiveFails++ + if !cb.open && cb.consecutiveFails >= threshold { + cb.open = true + slog.Warn("auth circuit breaker opened: too many consecutive auth failures, disabling retries", + "backend", backendID, "consecutive_failures", cb.consecutiveFails) + } +} + +// retryingBackendClient wraps a BackendClient and automatically retries operations that +// fail due to authentication errors (401/403). It uses: +// - Exponential backoff with a maximum of [maxAuthRetries] attempts +// - A per-backend circuit breaker to stop retrying after [authCircuitBreakerThreshold] consecutive failures +// - singleflight to deduplicate concurrent backoff waits for the same backend +// - OpenTelemetry spans to surface auth-retry latency in distributed traces +// +// Raw credentials are never logged. +type retryingBackendClient struct { + inner vmcp.BackendClient + registry vmcpauth.OutgoingAuthRegistry + + // sf deduplicates concurrent backoff waits for the same backend at the same attempt number. + sf singleflight.Group + + // breakers maps backendID -> *authCircuitBreaker. LoadOrStore is used for concurrent safety. + breakers sync.Map + + tracer trace.Tracer + maxRetries int + cbThreshold int + initialBackoff time.Duration + + // backoffFn is the sleep function used inside singleflight. nil uses time.After. + // Tests inject a counted hook to assert coalescing without real wall-clock delays. + backoffFn func(ctx context.Context, d time.Duration) error +} + +// newRetryingBackendClient wraps inner with auth-failure retry logic. +func newRetryingBackendClient(inner vmcp.BackendClient, registry vmcpauth.OutgoingAuthRegistry) *retryingBackendClient { + return &retryingBackendClient{ + inner: inner, + registry: registry, + tracer: otel.Tracer(authRetryInstrumentationName), + maxRetries: maxAuthRetries, + cbThreshold: authCircuitBreakerThreshold, + initialBackoff: initialRetryBackoff, + } +} + +// getBreaker returns (or lazily creates) the auth circuit breaker for a backend. +func (r *retryingBackendClient) getBreaker(backendID string) *authCircuitBreaker { + v, _ := r.breakers.LoadOrStore(backendID, &authCircuitBreaker{}) + return v.(*authCircuitBreaker) //nolint:forcetypeassert +} + +// withAuthRetry executes op, and if it returns ErrAuthenticationFailed, retries up to +// r.maxRetries times with exponential backoff, using singleflight to deduplicate concurrent +// backoff waits per backend. Auth-retry overhead is surfaced as an OpenTelemetry span. +func (r *retryingBackendClient) withAuthRetry( + ctx context.Context, + backendID string, + op func(context.Context) error, +) error { + breaker := r.getBreaker(backendID) + + err := op(ctx) + if err == nil { + breaker.recordSuccess() + return nil + } + if !errors.Is(err, vmcp.ErrAuthenticationFailed) { + return err + } + if !breaker.canRetry() { + slog.Debug("auth circuit breaker open, skipping auth retry", + "backend", backendID) + return err + } + + // Start a span to surface auth-retry latency in distributed traces. + ctx, span := r.tracer.Start(ctx, "auth.retry", + trace.WithAttributes( + attribute.String("target.workload_id", backendID), + attribute.Int("max_retries", r.maxRetries), + ), + trace.WithSpanKind(trace.SpanKindInternal), + ) + defer span.End() + + lastErr := err + backoff := r.initialBackoff + for attempt := 1; attempt <= r.maxRetries; attempt++ { + // Use singleflight to deduplicate concurrent backoff waits for the same backend + // and attempt number. The first goroutine sleeps; the others coalesce with it. + // DoChan is used instead of Do so every caller can also select on its own + // ctx.Done() — otherwise a coalesced caller with a short deadline would be + // stuck for the full backoff duration of the leader's longer-lived context. + sfKey := fmt.Sprintf("%s:attempt:%d", backendID, attempt) + // The singleflight function uses a detached context so that a cancelled + // leader goroutine does not propagate its error to all coalesced callers. + // Per-caller cancellation is handled by the outer select on ctx.Done() below. + detachedCtx := context.WithoutCancel(ctx) + currentBackoff := backoff + ch := r.sf.DoChan(sfKey, func() (any, error) { + if r.backoffFn != nil { + return nil, r.backoffFn(detachedCtx, currentBackoff) + } + select { + case <-detachedCtx.Done(): + return nil, detachedCtx.Err() + case <-time.After(currentBackoff): + return nil, nil + } + }) + var sfErr error + select { + case <-ctx.Done(): + sfErr = ctx.Err() + case res := <-ch: + sfErr = res.Err + } + if sfErr != nil { + span.RecordError(sfErr) + return sfErr + } + + span.AddEvent("auth.retry.attempt", + trace.WithAttributes(attribute.Int("attempt", attempt))) + + retryErr := op(ctx) + if retryErr == nil { + breaker.recordSuccess() + span.SetStatus(codes.Ok, "auth retry succeeded") + return nil + } + + lastErr = retryErr + if !errors.Is(retryErr, vmcp.ErrAuthenticationFailed) { + // Non-auth error on retry — no point continuing auth retries. + span.RecordError(retryErr) + return retryErr + } + backoff *= 2 + } + + // All retries exhausted with auth failures — update circuit breaker. + breaker.recordFailure(r.cbThreshold, backendID) + span.RecordError(lastErr) + span.SetStatus(codes.Error, "auth retry exhausted") + return lastErr +} + +// retryResult is a generic helper that wraps withAuthRetry for operations that return a value, +// eliminating the boilerplate of capturing a result variable in every BackendClient method. +func retryResult[T any]( + ctx context.Context, r *retryingBackendClient, backendID string, op func(context.Context) (T, error), +) (T, error) { + var result T + err := r.withAuthRetry(ctx, backendID, func(ctx context.Context) error { + var opErr error + result, opErr = op(ctx) + return opErr + }) + return result, err +} + +// CallTool implements vmcp.BackendClient. +func (r *retryingBackendClient) CallTool( + ctx context.Context, + target *vmcp.BackendTarget, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.ToolCallResult, error) { + return r.inner.CallTool(ctx, target, toolName, arguments, meta) + }) +} + +// ReadResource implements vmcp.BackendClient. +func (r *retryingBackendClient) ReadResource( + ctx context.Context, + target *vmcp.BackendTarget, + uri string, +) (*vmcp.ResourceReadResult, error) { + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.ResourceReadResult, error) { + return r.inner.ReadResource(ctx, target, uri) + }) +} + +// GetPrompt implements vmcp.BackendClient. +func (r *retryingBackendClient) GetPrompt( + ctx context.Context, + target *vmcp.BackendTarget, + name string, + arguments map[string]any, +) (*vmcp.PromptGetResult, error) { + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.PromptGetResult, error) { + return r.inner.GetPrompt(ctx, target, name, arguments) + }) +} + +// ListCapabilities implements vmcp.BackendClient. +func (r *retryingBackendClient) ListCapabilities( + ctx context.Context, + target *vmcp.BackendTarget, +) (*vmcp.CapabilityList, error) { + return retryResult(ctx, r, target.WorkloadID, func(ctx context.Context) (*vmcp.CapabilityList, error) { + return r.inner.ListCapabilities(ctx, target) + }) +} diff --git a/pkg/vmcp/client/auth_retry_integration_test.go b/pkg/vmcp/client/auth_retry_integration_test.go new file mode 100644 index 0000000000..84f53d65d8 --- /dev/null +++ b/pkg/vmcp/client/auth_retry_integration_test.go @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package client_test + +import ( + "context" + "net" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/auth" + "github.com/stacklok/toolhive/pkg/vmcp/auth/strategies" + vmcpclient "github.com/stacklok/toolhive/pkg/vmcp/client" +) + +// TestAuthRetry_Transient401_ListCapabilities verifies the end-to-end retry path when a +// backend MCP server returns HTTP 401 on the first request it receives. +// +// NewHTTPBackendClient wraps httpBackendClient with retryingBackendClient. +// ListCapabilities creates a fresh MCP client per call (Start + Initialize + List*). +// httpStatusRoundTripper intercepts the 401 response before mcp-go processes it, +// converting it to vmcp.ErrAuthenticationFailed, which retryingBackendClient detects +// via errors.Is and retries until success. +func TestAuthRetry_Transient401_ListCapabilities(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + backend, cleanup := startTransient401Server(t, &requestCount) + defer cleanup() + + registry := auth.NewDefaultOutgoingAuthRegistry() + require.NoError(t, registry.RegisterStrategy("unauthenticated", &strategies.UnauthenticatedStrategy{})) + + backendClient, err := vmcpclient.NewHTTPBackendClient(registry) + require.NoError(t, err) + + target := &vmcp.BackendTarget{ + WorkloadID: "test-backend", + WorkloadName: "Test Backend", + BaseURL: backend, + TransportType: "streamable-http", + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // ListCapabilities should succeed despite the initial 401 — the retry wrapper + // must recreate the MCP client and successfully complete the capability query. + caps, err := backendClient.ListCapabilities(ctx, target) + + require.NoError(t, err, "ListCapabilities should succeed after auth retry") + require.NotNil(t, caps) + assert.Len(t, caps.Tools, 1, "should discover the echo tool after retry") + + // Confirm the retry was exercised: the backend received more than one batch of + // requests (the 401 attempt + the successful retry). + assert.Greater(t, int(requestCount.Load()), 1, + "backend must have received >1 request, confirming retry was exercised") +} + +// startTransient401Server starts an httptest.Server backed by a real mcp-go MCP server. +// It returns 401 for the first request, then passes through to the real handler. +// The returned cleanup function must be deferred by the caller. +func startTransient401Server(tb testing.TB, requestCount *atomic.Int32) (baseURL string, cleanup func()) { + tb.Helper() + + mcpSrv := server.NewMCPServer("test-backend", "1.0.0", + server.WithToolCapabilities(true), + ) + mcpSrv.AddTool( + mcp.Tool{Name: "echo", Description: "Echo the input"}, + func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{mcp.NewTextContent("ok")}, + }, nil + }, + ) + + streamable := server.NewStreamableHTTPServer(mcpSrv, server.WithEndpointPath("/mcp")) + + httpSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodDelete { + // Allow session-close DELETE to pass through without counting. + streamable.ServeHTTP(w, r) + return + } + n := requestCount.Add(1) + if n <= 1 { + w.WriteHeader(http.StatusUnauthorized) + return + } + streamable.ServeHTTP(w, r) + })) + + // Bind to a free port on loopback. + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(tb, err) + httpSrv.Listener = ln + httpSrv.Start() + + tb.Logf("started transient-401 backend at %s/mcp (will fail first non-DELETE request)", httpSrv.URL) + + return httpSrv.URL + "/mcp", httpSrv.Close +} diff --git a/pkg/vmcp/client/auth_retry_test.go b/pkg/vmcp/client/auth_retry_test.go new file mode 100644 index 0000000000..5177363659 --- /dev/null +++ b/pkg/vmcp/client/auth_retry_test.go @@ -0,0 +1,411 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package client + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// authErr wraps ErrAuthenticationFailed so errors.Is() matches. +func authErr(msg string) error { + return fmt.Errorf("%w: %s", vmcp.ErrAuthenticationFailed, msg) +} + +// stubBackendClient is a simple stub that returns a pre-configured sequence of errors/results. +type stubBackendClient struct { + mu sync.Mutex + callErrs []error // errors to return in order (nil = success) + callIdx int + calls int +} + +func (s *stubBackendClient) nextErr() error { + s.mu.Lock() + defer s.mu.Unlock() + s.calls++ + if s.callIdx >= len(s.callErrs) { + return nil + } + err := s.callErrs[s.callIdx] + s.callIdx++ + return err +} + +func (s *stubBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + if err := s.nextErr(); err != nil { + return nil, err + } + return &vmcp.ToolCallResult{}, nil +} + +func (s *stubBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) (*vmcp.ResourceReadResult, error) { + if err := s.nextErr(); err != nil { + return nil, err + } + return &vmcp.ResourceReadResult{}, nil +} + +func (s *stubBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + if err := s.nextErr(); err != nil { + return nil, err + } + return &vmcp.PromptGetResult{}, nil +} + +func (s *stubBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if err := s.nextErr(); err != nil { + return nil, err + } + return &vmcp.CapabilityList{}, nil +} + +func makeTarget(id string) *vmcp.BackendTarget { + return &vmcp.BackendTarget{ + WorkloadID: id, + WorkloadName: id, + BaseURL: "http://localhost:8080", + TransportType: "streamable-http", + } +} + +// newFastRetryClient creates a retryingBackendClient with minimal backoff for tests. +func newFastRetryClient(inner vmcp.BackendClient) *retryingBackendClient { + c := newRetryingBackendClient(inner, nil) + c.initialBackoff = time.Millisecond // fast for tests + c.tracer = noop.NewTracerProvider().Tracer("test") + return c +} + +// TestRetryingBackendClient_SuccessOnFirstAttempt verifies that operations that succeed +// immediately are passed through without any retry overhead. +func TestRetryingBackendClient_SuccessOnFirstAttempt(t *testing.T) { + t.Parallel() + + stub := &stubBackendClient{callErrs: []error{nil}} + c := newFastRetryClient(stub) + target := makeTarget("backend-1") + + result, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 1, stub.calls) +} + +// TestRetryingBackendClient_SuccessOnFirstAttempt_ResetsBreaker verifies that a first-attempt +// success resets the circuit breaker, so prior failures don't accumulate indefinitely. +func TestRetryingBackendClient_SuccessOnFirstAttempt_ResetsBreaker(t *testing.T) { + t.Parallel() + + stub := &stubBackendClient{callErrs: []error{nil}} + c := newFastRetryClient(stub) + target := makeTarget("backend-reset-initial") + + // Prime the breaker with stale failures from a previous sequence. + breaker := c.getBreaker(target.WorkloadID) + breaker.consecutiveFails = 3 + + // A first-attempt success must reset the breaker counter to zero. + _, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + require.NoError(t, err) + assert.Equal(t, 0, breaker.consecutiveFails) + assert.False(t, breaker.open) +} + +// TestRetryingBackendClient_SuccessAfterAuthFailure verifies that a single 401 is retried +// successfully and the operation returns the successful result. +func TestRetryingBackendClient_SuccessAfterAuthFailure(t *testing.T) { + t.Parallel() + + stub := &stubBackendClient{callErrs: []error{authErr("401 unauthorized"), nil}} + c := newFastRetryClient(stub) + target := makeTarget("backend-1") + + result, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + + require.NoError(t, err) + assert.NotNil(t, result) + // First call + one retry + assert.Equal(t, 2, stub.calls) +} + +// TestRetryingBackendClient_MaxRetriesExhausted verifies that after maxAuthRetries, the +// last error is returned and no further retries are attempted. +func TestRetryingBackendClient_MaxRetriesExhausted(t *testing.T) { + t.Parallel() + + // All calls fail with auth error (1 initial + maxAuthRetries retries) + errs := make([]error, maxAuthRetries+1) + for i := range errs { + errs[i] = authErr("401 unauthorized") + } + stub := &stubBackendClient{callErrs: errs} + c := newFastRetryClient(stub) + target := makeTarget("backend-1") + + result, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + + require.Error(t, err) + assert.Nil(t, result) + assert.True(t, errors.Is(err, vmcp.ErrAuthenticationFailed)) + // 1 initial attempt + maxAuthRetries retries + assert.Equal(t, maxAuthRetries+1, stub.calls) +} + +// TestRetryingBackendClient_NonAuthErrorNotRetried verifies that non-auth errors are +// returned immediately without any retry. +func TestRetryingBackendClient_NonAuthErrorNotRetried(t *testing.T) { + t.Parallel() + + nonAuthErr := fmt.Errorf("%w: connection refused", vmcp.ErrBackendUnavailable) + stub := &stubBackendClient{callErrs: []error{nonAuthErr}} + c := newFastRetryClient(stub) + target := makeTarget("backend-1") + + _, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + + require.Error(t, err) + assert.True(t, errors.Is(err, vmcp.ErrBackendUnavailable)) + // Only 1 call — no retries for non-auth errors + assert.Equal(t, 1, stub.calls) +} + +// TestRetryingBackendClient_CircuitBreakerOpens verifies that after N consecutive auth +// failures the circuit breaker opens and further retries are skipped. +func TestRetryingBackendClient_CircuitBreakerOpens(t *testing.T) { + t.Parallel() + + stub := &stubBackendClient{} + // Always return auth error + for i := 0; i < 100; i++ { + stub.callErrs = append(stub.callErrs, authErr("401 unauthorized")) + } + + c := newFastRetryClient(stub) + c.cbThreshold = 2 // open after 2 consecutive failures + target := makeTarget("backend-cb") + + // Drive enough failures to open the circuit breaker. + // Each CallTool call: 1 initial + maxAuthRetries retries = 4 calls total. + // After cbThreshold (2) complete retry sequences fail, circuit opens. + for i := 0; i < c.cbThreshold; i++ { + _, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + require.Error(t, err) + } + + // Circuit should now be open + breaker := c.getBreaker(target.WorkloadID) + assert.True(t, breaker.open, "circuit breaker should be open after threshold failures") + + // Further calls should not retry — only 1 attempt (the initial call) + callsBefore := stub.calls + _, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + require.Error(t, err) + assert.Equal(t, callsBefore+1, stub.calls, "circuit open: should make exactly 1 call with no retries") +} + +// TestRetryingBackendClient_CircuitBreakerResetOnSuccess verifies that the circuit breaker +// resets its counter after a successful operation. +func TestRetryingBackendClient_CircuitBreakerResetOnSuccess(t *testing.T) { + t.Parallel() + + // Fail once, then succeed — should reset the failure counter + stub := &stubBackendClient{callErrs: []error{authErr("401"), nil}} + c := newFastRetryClient(stub) + c.cbThreshold = 2 + target := makeTarget("backend-reset") + + _, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + require.NoError(t, err) + + breaker := c.getBreaker(target.WorkloadID) + assert.Equal(t, 0, breaker.consecutiveFails) + assert.False(t, breaker.open) +} + +// TestRetryingBackendClient_ConcurrentFailuresDeduplicated verifies that concurrent auth +// failures for the same backend result in only one backoff wait per attempt (via singleflight). +func TestRetryingBackendClient_ConcurrentFailuresDeduplicated(t *testing.T) { + t.Parallel() + + const concurrency = 10 + + // failWG counts down each time the stub returns a failure. The backoffFn + // (running inside singleflight) waits until all goroutines have completed + // their initial failing call — at that point the other 9 are already + // coalesced on sf.Do — then records the sleep and returns. + var failWG sync.WaitGroup + failWG.Add(concurrency) + + var opCount atomic.Int64 + inner := &countingBackendClient{ + callCount: &opCount, + failFirst: true, + failCount: concurrency, // first 'concurrency' calls return auth failure + onFail: failWG.Done, // called synchronously when a failure is returned + } + + var sleepCount atomic.Int64 + c := newFastRetryClient(inner) + // Inject a backoff hook that waits for all initial failures before proceeding. + // Because backoffFn runs inside singleflight.Do, it is called exactly once; + // all other goroutines block on sf.Do until this returns. Waiting for failWG + // here guarantees they have all arrived and are coalesced before we assert. + c.backoffFn = func(ctx context.Context, _ time.Duration) error { + done := make(chan struct{}) + go func() { failWG.Wait(); close(done) }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + } + sleepCount.Add(1) + return nil + } + target := makeTarget("backend-concurrent") + + var wg sync.WaitGroup + start := make(chan struct{}) + type callResult struct { + result *vmcp.ToolCallResult + err error + } + results := make(chan callResult, concurrency) + + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + result, err := c.CallTool(context.Background(), target, "tool1", nil, nil) + results <- callResult{result, err} + }() + } + + close(start) + wg.Wait() + close(results) + + // All goroutines should succeed. + for r := range results { + require.NoError(t, r.err) + assert.NotNil(t, r.result) + } + + // singleflight must have coalesced: backoffFn fires exactly once for attempt 1, + // not once per goroutine. + assert.Equal(t, int64(1), sleepCount.Load(), + "singleflight should coalesce backoff waits into a single sleep invocation") +} + +// countingBackendClient is a thread-safe stub that returns auth errors for the first +// failCount calls when failFirst is set, then succeeds for all subsequent calls. +// onFail, if set, is called synchronously each time a failure is returned. +type countingBackendClient struct { + callCount *atomic.Int64 + failFirst bool + failCount int // number of initial calls that return auth error (0 means just the first) + onFail func() // called each time a failure is returned +} + +func (c *countingBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + n := c.callCount.Add(1) + threshold := int64(1) + if c.failCount > 0 { + threshold = int64(c.failCount) + } + if c.failFirst && n <= threshold { + if c.onFail != nil { + c.onFail() + } + return nil, authErr("401 unauthorized") + } + return &vmcp.ToolCallResult{}, nil +} + +func (*countingBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) (*vmcp.ResourceReadResult, error) { + return &vmcp.ResourceReadResult{}, nil +} + +func (*countingBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return &vmcp.PromptGetResult{}, nil +} + +func (*countingBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + return &vmcp.CapabilityList{}, nil +} + +// TestRetryingBackendClient_AllMethods verifies that all four BackendClient methods go +// through the retry logic (success-after-failure scenario). +func TestRetryingBackendClient_AllMethods(t *testing.T) { + t.Parallel() + target := makeTarget("backend-all") + + t.Run("ReadResource retries on auth failure", func(t *testing.T) { + t.Parallel() + stub := &stubBackendClient{callErrs: []error{authErr("403 forbidden"), nil}} + c := newFastRetryClient(stub) + result, err := c.ReadResource(context.Background(), target, "res://foo") + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 2, stub.calls) + }) + + t.Run("GetPrompt retries on auth failure", func(t *testing.T) { + t.Parallel() + stub := &stubBackendClient{callErrs: []error{authErr("401 unauthorized"), nil}} + c := newFastRetryClient(stub) + result, err := c.GetPrompt(context.Background(), target, "my-prompt", nil) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 2, stub.calls) + }) + + t.Run("ListCapabilities retries on auth failure", func(t *testing.T) { + t.Parallel() + stub := &stubBackendClient{callErrs: []error{authErr("401 unauthorized"), nil}} + c := newFastRetryClient(stub) + result, err := c.ListCapabilities(context.Background(), target) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 2, stub.calls) + }) +} + +// TestRetryingBackendClient_ContextCancellation verifies that a cancelled context aborts +// the retry backoff cleanly. +func TestRetryingBackendClient_ContextCancellation(t *testing.T) { + t.Parallel() + + stub := &stubBackendClient{} + // Always auth-fail so the retry loop is entered + for i := 0; i < 10; i++ { + stub.callErrs = append(stub.callErrs, authErr("401")) + } + + c := newFastRetryClient(stub) + // Use a long backoff so context cancellation can interrupt it + c.initialBackoff = 500 * time.Millisecond + target := makeTarget("backend-ctx") + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err := c.CallTool(ctx, target, "tool1", nil, nil) + require.Error(t, err) + // Should get context deadline exceeded, not an auth error + assert.True(t, errors.Is(err, context.DeadlineExceeded)) +} diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 01642d6641..e96a1b095a 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -79,7 +79,44 @@ func NewHTTPBackendClient(registry vmcpauth.OutgoingAuthRegistry) (vmcp.BackendC registry: registry, } c.clientFactory = c.defaultClientFactory - return c, nil + return newRetryingBackendClient(c, registry), nil +} + +// httpStatusRoundTripper converts HTTP 401, 403, and 5xx responses into structured +// sentinel errors before mcp-go processes the response. This enables type-safe +// errors.Is() checks throughout the error-handling chain without string matching. +type httpStatusRoundTripper struct { + base http.RoundTripper +} + +// RoundTrip implements http.RoundTripper. It intercepts authentication and server +// error status codes, converting them to sentinel errors and closing the response body. +func (h *httpStatusRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := h.base.RoundTrip(req) + if err != nil { + return nil, err + } + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + drainAndClose(resp.Body, resp.StatusCode) + return nil, fmt.Errorf("%w: HTTP %d", vmcp.ErrAuthenticationFailed, resp.StatusCode) + case http.StatusInternalServerError, http.StatusBadGateway, + http.StatusServiceUnavailable, http.StatusGatewayTimeout: + drainAndClose(resp.Body, resp.StatusCode) + return nil, fmt.Errorf("%w: HTTP %d", vmcp.ErrBackendUnavailable, resp.StatusCode) + } + return resp, nil +} + +// drainAndClose drains up to maxResponseSize bytes from the body before closing it, +// allowing the underlying TCP connection to be reused by the transport. +func drainAndClose(body io.ReadCloser, statusCode int) { + if _, err := io.Copy(io.Discard, io.LimitReader(body, maxResponseSize)); err != nil { + slog.Debug("failed to drain response body", "status", statusCode, "error", err) + } + if err := body.Close(); err != nil { + slog.Debug("failed to close response body", "status", statusCode, "error", err) + } } // roundTripperFunc is a function adapter for http.RoundTripper. @@ -236,7 +273,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm return resp, nil }) httpClient := &http.Client{ - Transport: sizeLimitedTransport, + Transport: &httpStatusRoundTripper{base: sizeLimitedTransport}, Timeout: 30 * time.Second, } c, err = client.NewStreamableHttpClient( @@ -253,7 +290,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm // Applying io.LimitReader would silently terminate the stream after // maxResponseSize cumulative bytes — not per-event — which is wrong. // http.Client.Timeout is also omitted: it would kill the stream. - httpClient := &http.Client{Transport: baseTransport} + httpClient := &http.Client{Transport: &httpStatusRoundTripper{base: baseTransport}} c, err = client.NewSSEMCPClient( target.BaseURL, transport.WithHTTPClient(httpClient), @@ -321,21 +358,29 @@ func wrapBackendError(err error, backendID string, operation string) error { vmcp.ErrTimeout, operation, backendID, err) } - // 4. String-based detection: Fall back to pattern matching for cases where - // we don't have structured error types (MCP SDK, HTTP libraries with embedded status codes) - // Authentication errors (401, 403, auth failures) + // 4. Sentinel errors set by our httpStatusRoundTripper at the HTTP layer. + // These cover 401/403 (auth) and 5xx (backend unavailable) responses. + if errors.Is(err, vmcp.ErrAuthenticationFailed) { + return fmt.Errorf("%w: failed to %s for backend %s: %v", + vmcp.ErrAuthenticationFailed, operation, backendID, err) + } + if errors.Is(err, vmcp.ErrBackendUnavailable) { + return fmt.Errorf("%w: failed to %s for backend %s: %v", + vmcp.ErrBackendUnavailable, operation, backendID, err) + } + + // 5. String-based fallback for errors not covered above (e.g. from test stubs, + // non-HTTP transports, or external libraries that don't use sentinel errors). if vmcp.IsAuthenticationError(err) { return fmt.Errorf("%w: failed to %s for backend %s: %v", vmcp.ErrAuthenticationFailed, operation, backendID, err) } - // Timeout errors (deadline exceeded, timeout messages) if vmcp.IsTimeoutError(err) { return fmt.Errorf("%w: failed to %s for backend %s (timeout): %v", vmcp.ErrTimeout, operation, backendID, err) } - // Connection errors (refused, reset, unreachable) if vmcp.IsConnectionError(err) { return fmt.Errorf("%w: failed to %s for backend %s (connection error): %v", vmcp.ErrBackendUnavailable, operation, backendID, err) @@ -563,8 +608,7 @@ func (h *httpBackendClient) CallTool( }, }) if err != nil { - // Network/connection errors are operational errors - return nil, fmt.Errorf("%w: tool call failed on backend %s: %w", vmcp.ErrBackendUnavailable, target.WorkloadID, err) + return nil, wrapBackendError(err, target.WorkloadID, "call tool") } // Extract _meta field from backend response @@ -663,7 +707,7 @@ func (h *httpBackendClient) ReadResource( }, }) if err != nil { - return nil, fmt.Errorf("resource read failed on backend %s: %w", target.WorkloadID, err) + return nil, wrapBackendError(err, target.WorkloadID, "read resource") } // Concatenate all resource content items into a single byte slice. @@ -723,7 +767,7 @@ func (h *httpBackendClient) GetPrompt( }, }) if err != nil { - return nil, fmt.Errorf("prompt get failed on backend %s: %w", target.WorkloadID, err) + return nil, wrapBackendError(err, target.WorkloadID, "get prompt") } return &vmcp.PromptGetResult{ diff --git a/pkg/vmcp/client/client_test.go b/pkg/vmcp/client/client_test.go index 5fae752c91..ebb3727eca 100644 --- a/pkg/vmcp/client/client_test.go +++ b/pkg/vmcp/client/client_test.go @@ -129,7 +129,8 @@ func TestDefaultClientFactory_UnsupportedTransport(t *testing.T) { backendClient, err := NewHTTPBackendClient(mockRegistry) require.NoError(t, err) - httpClient := backendClient.(*httpBackendClient) + retryClient := backendClient.(*retryingBackendClient) + httpClient := retryClient.inner.(*httpBackendClient) _, err = httpClient.defaultClientFactory(context.Background(), target) @@ -750,7 +751,8 @@ func TestResolveAuthStrategy(t *testing.T) { backendClient, err := NewHTTPBackendClient(registry) require.NoError(t, err) - httpClient := backendClient.(*httpBackendClient) + retryClient := backendClient.(*retryingBackendClient) + httpClient := retryClient.inner.(*httpBackendClient) // Call resolveAuthStrategy strategy, err := httpClient.resolveAuthStrategy(tt.target) diff --git a/pkg/vmcp/errors.go b/pkg/vmcp/errors.go index 1bc038ca1f..492b2fae38 100644 --- a/pkg/vmcp/errors.go +++ b/pkg/vmcp/errors.go @@ -82,6 +82,29 @@ var ( // code should prefer errors.Is() checks over these string-based functions. // These functions remain for backwards compatibility and as a fallback mechanism. +// authErrorPatterns lists lowercase substrings that identify authentication errors. +// Patterns cover multiple error formats: +// - Standard HTTP: "401 unauthorized", "403 forbidden", "http 401", "status code 401" +// - mcp-go SDK: ErrUnauthorized = "unauthorized (401)" +// - mcp-go generic: "request failed with status 401/403: ..." +// - Explicit messages: "authentication failed", "request unauthenticated", "access denied" +var authErrorPatterns = []string{ + "authentication failed", + "authentication error", + "401 unauthorized", + "403 forbidden", + "http 401", + "http 403", + "status code 401", + "status code 403", + "unauthorized (401)", + "request failed with status 401", + "request failed with status 403", + "request unauthenticated", + "request unauthorized", + "access denied", +} + // IsAuthenticationError checks if an error message indicates an authentication failure. // Uses case-insensitive pattern matching to detect various auth error formats from // HTTP libraries, MCP protocol errors, and authentication middleware. @@ -89,33 +112,12 @@ func IsAuthenticationError(err error) bool { if err == nil { return false } - errLower := strings.ToLower(err.Error()) - - // Check for explicit authentication failure messages - if strings.Contains(errLower, "authentication failed") || - strings.Contains(errLower, "authentication error") { - return true - } - - // Check for HTTP 401/403 status codes with context - // Match patterns like "401 Unauthorized", "HTTP 401", "status code 401" - if strings.Contains(errLower, "401 unauthorized") || - strings.Contains(errLower, "403 forbidden") || - strings.Contains(errLower, "http 401") || - strings.Contains(errLower, "http 403") || - strings.Contains(errLower, "status code 401") || - strings.Contains(errLower, "status code 403") { - return true - } - - // Check for explicit unauthenticated/unauthorized errors - if strings.Contains(errLower, "request unauthenticated") || - strings.Contains(errLower, "request unauthorized") || - strings.Contains(errLower, "access denied") { - return true + for _, pattern := range authErrorPatterns { + if strings.Contains(errLower, pattern) { + return true + } } - return false } @@ -133,6 +135,17 @@ func IsTimeoutError(err error) bool { strings.Contains(errLower, "context deadline exceeded") } +// connectionErrorPatterns lists lowercase substrings that identify connection failures. +// Covers network-level errors, broken pipes, and HTTP 5xx server errors. +var connectionErrorPatterns = []string{ + "connection refused", "connection reset", "no route to host", + "network is unreachable", "broken pipe", "connection closed", + "500 internal server error", "502 bad gateway", + "503 service unavailable", "504 gateway timeout", + "status code 500", "status code 502", + "status code 503", "status code 504", +} + // IsConnectionError checks if an error message indicates a connection failure. // Detects network-level errors like connection refused, reset, unreachable, etc. // Also detects broken pipes, EOF errors, and HTTP 5xx server errors that indicate @@ -141,38 +154,16 @@ func IsConnectionError(err error) bool { if err == nil { return false } - errStr := err.Error() - errLower := strings.ToLower(errStr) - - // Check against list of known connection error patterns - networkPatterns := []string{ - "connection refused", "connection reset", "no route to host", - "network is unreachable", "broken pipe", "connection closed", - } - for _, pattern := range networkPatterns { - if strings.Contains(errLower, pattern) { - return true - } - } - // EOF errors (be specific - check exact case to avoid false positives) if strings.Contains(errStr, "EOF") { return true } - - // HTTP 5xx server errors - httpErrorPatterns := []string{ - "500 internal server error", "502 bad gateway", - "503 service unavailable", "504 gateway timeout", - "status code 500", "status code 502", - "status code 503", "status code 504", - } - for _, pattern := range httpErrorPatterns { + errLower := strings.ToLower(errStr) + for _, pattern := range connectionErrorPatterns { if strings.Contains(errLower, pattern) { return true } } - return false } diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index 337e232f7e..4ec39c020a 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -26,6 +26,7 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/kubernetes" @@ -78,6 +79,27 @@ func WaitForVirtualMCPServerReady( }, timeout, pollingInterval).Should(gomega.Succeed()) } +// WaitForVirtualMCPServerPod waits for the VirtualMCPServer's own pod to be running and +// ready, without requiring any backend health condition to be True. Use this when the +// VirtualMCPServer is expected to remain Degraded (e.g., when testing auth-retry +// exhaustion with a permanently-failing backend). +func WaitForVirtualMCPServerPod( + ctx context.Context, + c client.Client, + name, namespace string, + timeout time.Duration, + pollingInterval time.Duration, +) { + labels := map[string]string{ + "app.kubernetes.io/name": "virtualmcpserver", + "app.kubernetes.io/instance": name, + } + gomega.Eventually(func() error { + return checkPodsReady(ctx, c, namespace, labels) + }, timeout, pollingInterval).Should(gomega.Succeed(), + "VirtualMCPServer pod should be running") +} + // checkPodsReady waits for at least one pod matching the given labels to be ready. // This is used when checking for a single expected pod (e.g., one replica deployment). // Pods not in Running phase are skipped (e.g., Succeeded, Failed from previous deployments). @@ -811,6 +833,18 @@ type BackendConfig struct { // defaultMCPServerResources() is used to ensure containers are scheduled // with reasonable resource guarantees and do not compete excessively. Resources *mcpv1alpha1.ResourceRequirements + // Args are extra arguments passed to the MCP server image entry-point. + // Useful for images like python:3.x-slim where the script is provided + // inline (e.g. Args: []string{"-c", ""}). + Args []string + // PodTemplateSpec is an optional JSON-encoded patch applied to the pod + // spec that the toolhive runner creates for the MCP server container. + // Use this to override readiness probes, resource limits, etc. + PodTemplateSpec *runtime.RawExtension + // SkipReadinessWait skips waiting for this backend to reach MCPServerPhaseRunning. + // Use this for backends that are intentionally broken (e.g. persistent 401 servers) + // where the operator is expected to mark the MCPServer as Failed. + SkipReadinessWait bool } // defaultMCPServerResources returns conservative resource requests/limits that @@ -864,6 +898,8 @@ func CreateMultipleMCPServersInParallel( ExternalAuthConfigRef: backends[idx].ExternalAuthConfigRef, Secrets: backends[idx].Secrets, Resources: resources, + Args: backends[idx].Args, + PodTemplateSpec: backends[idx].PodTemplateSpec, Env: append([]mcpv1alpha1.EnvVar{ {Name: "TRANSPORT", Value: backendTransport}, }, backends[idx].Env...), @@ -872,9 +908,14 @@ func CreateMultipleMCPServersInParallel( gomega.Expect(c.Create(ctx, backend)).To(gomega.Succeed()) } - // Wait for all backends to be ready in parallel (single Eventually checking all) + // Wait for all backends that require readiness to reach Running phase. + // Backends with SkipReadinessWait=true are created but excluded from this check + // (e.g. intentionally broken servers expected to be marked Failed by the operator). gomega.Eventually(func() error { for _, cfg := range backends { + if cfg.SkipReadinessWait { + continue + } server := &mcpv1alpha1.MCPServer{} err := c.Get(ctx, types.NamespacedName{ Name: cfg.Name, @@ -891,7 +932,7 @@ func CreateMultipleMCPServersInParallel( return fmt.Errorf("%s not ready yet, phase: %s", cfg.Name, server.Status.Phase) } } - // All backends are ready + // All watched backends are ready return nil }, timeout, pollingInterval).Should(gomega.Succeed(), "All MCPServers should be ready") } @@ -1805,6 +1846,9 @@ func DeployMockOAuth2Server( // ---- /status and /api/backends/health HTTP helpers ---- +// backendHealthStatusHealthy is the health status string for a healthy backend. +const backendHealthStatusHealthy = "healthy" + // VMCPStatusResponse mirrors server.StatusResponse // (pkg/vmcp/server/status.go) for test deserialization. type VMCPStatusResponse struct { diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_retry_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_retry_test.go new file mode 100644 index 0000000000..bbf8144c9e --- /dev/null +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_retry_test.go @@ -0,0 +1,287 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package virtualmcp + +import ( + "encoding/json" + "fmt" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" + + mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/test/e2e/images" +) + +// persistent401BackendScript is an inline Python HTTP server that returns +// HTTP 401 Unauthorized for every request on port 8080. +// It simulates a backend whose credentials are permanently invalid, letting us +// verify that retryingBackendClient exhausts maxAuthRetries (3) and surfaces +// ErrAuthenticationFailed → BackendUnauthenticated → BackendStatusUnavailable. +// +// ThreadingMixIn + HTTPServer is used instead of bare TCPServer so that: +// - Concurrent connections from the ToolHive proxy are handled in separate threads +// - BrokenPipeError from abruptly-closed connections does not crash the process +// - allow_reuse_address avoids "Address already in use" on pod restart +const persistent401BackendScript = `import http.server,socketserver +class H(http.server.BaseHTTPRequestHandler): + def do_GET(self): + try:self.send_response(401);self.end_headers() + except Exception:pass + do_POST=do_PUT=do_DELETE=do_PATCH=do_HEAD=do_OPTIONS=do_GET + def log_message(self,*a):pass + def handle_error(self,r,a):pass +class S(socketserver.ThreadingMixIn,http.server.HTTPServer): + allow_reuse_address=True + daemon_threads=True +S(('',8080),H).serve_forever()` + +// build401PodTemplateSpec returns a PodTemplateSpec patch that replaces the +// default HTTP readiness probe on the "mcp" container with a TCP socket probe. +// Without this, the runner's HTTP GET /health probe would receive 401 and the +// container would never become Ready. +func build401PodTemplateSpec() *runtime.RawExtension { + podTemplateSpec := corev1.PodTemplateSpec{ + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "mcp", + ReadinessProbe: &corev1.Probe{ + ProbeHandler: corev1.ProbeHandler{ + TCPSocket: &corev1.TCPSocketAction{ + Port: intstr.FromInt(8080), + }, + }, + InitialDelaySeconds: 2, + PeriodSeconds: 2, + TimeoutSeconds: 5, + FailureThreshold: 10, + }, + }, + }, + }, + } + raw, err := json.Marshal(podTemplateSpec) + Expect(err).ToNot(HaveOccurred(), "should marshal PodTemplateSpec to JSON") + return &runtime.RawExtension{Raw: raw} +} + +// TestAuthRetry_PersistentUnauthorized_BackendMarkedUnauthenticated verifies the +// end-to-end auth-retry pipeline in a live Kubernetes cluster: +// +// 1. A backend MCPServer runs a Python HTTP server returning 401 for every request. +// 2. retryingBackendClient intercepts the 401, retries up to maxAuthRetries (3) +// times with exponential back-off, then returns ErrAuthenticationFailed. +// 3. The health monitor maps this to BackendUnauthenticated → BackendStatusUnavailable. +// 4. A co-located healthy backend (yardstick) stays Ready throughout. +var _ = Describe("VirtualMCPServer Auth Retry Exhaustion", Ordered, func() { + var ( + testNamespace = "default" + mcpGroupName = "test-auth-retry-group" + vmcpServerName = "test-vmcp-auth-retry" + stableBackend = "backend-auth-stable" + failingBackend = "backend-auth-failing-401" + timeout = 3 * time.Minute + pollInterval = 2 * time.Second + ) + + BeforeAll(func() { + By("Creating MCPGroup for auth retry tests") + CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, + "Test MCP Group for auth retry E2E tests", timeout, pollInterval) + + By("Creating stable and persistent-401 backend MCPServers") + CreateMultipleMCPServersInParallel(ctx, k8sClient, []BackendConfig{ + { + Name: stableBackend, + Namespace: testNamespace, + GroupRef: mcpGroupName, + Image: images.YardstickServerImage, + }, + { + Name: failingBackend, + Namespace: testNamespace, + GroupRef: mcpGroupName, + Image: images.PythonImage, + // Pass the inline 401 server script to the Python interpreter. + Args: []string{"-c", persistent401BackendScript}, + // Replace the default HTTP readiness probe with a TCP one so + // the container becomes Ready as soon as port 8080 is open, + // regardless of the HTTP 401 responses it serves. + PodTemplateSpec: build401PodTemplateSpec(), + // The operator will mark this MCPServer as Failed because every + // MCP request returns 401. Skip the readiness gate so BeforeAll + // does not stop-trying on the expected Failed phase. + SkipReadinessWait: true, + }, + }, timeout, pollInterval) + + By("Creating VirtualMCPServer") + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: vmcpServerName, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "discovered", + }, + ServiceType: "NodePort", + Config: vmcpconfig.Config{ + Name: vmcpServerName, + Group: mcpGroupName, + Aggregation: &vmcpconfig.AggregationConfig{ + ConflictResolution: "prefix", + }, + }, + }, + } + Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) + + By("Waiting for VirtualMCPServer pod to be running") + // The VirtualMCPServer will be Degraded (not Ready) because one backend always + // returns 401. Wait for the pod itself to be up so health-checking can proceed. + WaitForVirtualMCPServerPod(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollInterval) + }) + + AfterAll(func() { + By("Cleaning up auth retry test resources") + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer); err == nil { + Expect(k8sClient.Delete(ctx, vmcpServer)).To(Succeed()) + } + + for _, name := range []string{stableBackend, failingBackend} { + server := &mcpv1alpha1.MCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: name, + Namespace: testNamespace, + }, server); err == nil { + Expect(k8sClient.Delete(ctx, server)).To(Succeed()) + } + } + + group := &mcpv1alpha1.MCPGroup{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: mcpGroupName, + Namespace: testNamespace, + }, group); err == nil { + Expect(k8sClient.Delete(ctx, group)).To(Succeed()) + } + }) + + It("should mark the 401 backend as unavailable after auth retries are exhausted", func() { + // retryingBackendClient retries ListCapabilities up to maxAuthRetries (3) + // times. After exhaustion it returns ErrAuthenticationFailed, which the + // health monitor maps to BackendUnauthenticated → "unavailable" in the CRD. + Eventually(func() error { + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer); err != nil { + return err + } + + var failingB *mcpv1alpha1.DiscoveredBackend + for i := range vmcpServer.Status.DiscoveredBackends { + if vmcpServer.Status.DiscoveredBackends[i].Name == failingBackend { + failingB = &vmcpServer.Status.DiscoveredBackends[i] + break + } + } + if failingB == nil { + return fmt.Errorf("401 backend %q not yet in discovered backends", failingBackend) + } + if failingB.Status != mcpv1alpha1.BackendStatusUnavailable { + return fmt.Errorf("expected status %q, got %q (message: %s)", + mcpv1alpha1.BackendStatusUnavailable, failingB.Status, failingB.Message) + } + + GinkgoWriter.Printf("✓ 401 backend unavailable (status: %s, message: %s)\n", + failingB.Status, failingB.Message) + return nil + }, timeout, pollInterval).Should(Succeed()) + }) + + It("should transition VirtualMCPServer to Degraded when a backend is unavailable", func() { + Eventually(func() error { + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + if err := k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer); err != nil { + return err + } + if vmcpServer.Status.Phase != mcpv1alpha1.VirtualMCPServerPhaseDegraded && + vmcpServer.Status.Phase != mcpv1alpha1.VirtualMCPServerPhaseFailed { + return fmt.Errorf("expected phase Degraded or Failed, got: %s", + vmcpServer.Status.Phase) + } + GinkgoWriter.Printf("✓ VirtualMCPServer phase: %s\n", vmcpServer.Status.Phase) + return nil + }, timeout, pollInterval).Should(Succeed()) + }) + + It("should keep the stable backend ready throughout the auth failure", func() { + // Auth failures are isolated per-backend via the per-backend circuit breaker. + vmcpServer := &mcpv1alpha1.VirtualMCPServer{} + Expect(k8sClient.Get(ctx, types.NamespacedName{ + Name: vmcpServerName, + Namespace: testNamespace, + }, vmcpServer)).To(Succeed()) + + var stableB *mcpv1alpha1.DiscoveredBackend + for i := range vmcpServer.Status.DiscoveredBackends { + if vmcpServer.Status.DiscoveredBackends[i].Name == stableBackend { + stableB = &vmcpServer.Status.DiscoveredBackends[i] + break + } + } + Expect(stableB).NotTo(BeNil(), "stable backend should be in discovered backends list") + Expect(stableB.Status).To(Or( + Equal(mcpv1alpha1.BackendStatusReady), + Equal(mcpv1alpha1.BackendStatusDegraded)), + "stable backend should remain healthy; got status=%s message=%s", + stableB.Status, stableB.Message) + + GinkgoWriter.Printf("✓ Stable backend remained healthy: status=%s\n", stableB.Status) + }) + + It("should report the 401 backend as unavailable in /api/backends/health", func() { + nodePort := GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollInterval) + + Eventually(func() error { + bh, err := GetVMCPBackendsHealth(nodePort) + if err != nil { + return fmt.Errorf("GET /api/backends/health: %w", err) + } + if !bh.MonitoringEnabled { + return fmt.Errorf("monitoring not enabled") + } + state, found := bh.Backends[failingBackend] + if !found { + return fmt.Errorf("401 backend %q not found in /api/backends/health", failingBackend) + } + if state.Status == backendHealthStatusHealthy { + return fmt.Errorf("401 backend still reported as healthy") + } + GinkgoWriter.Printf("✓ /api/backends/health: %s → %s\n", failingBackend, state.Status) + return nil + }, timeout, pollInterval).Should(Succeed()) + }) +}) diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go index ce34ae8882..a6439fce25 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_circuit_breaker_test.go @@ -421,14 +421,14 @@ var _ = Describe("VirtualMCPServer Circuit Breaker Lifecycle", Ordered, func() { if !inHealth { return fmt.Errorf("unstable backend %q not found in /api/backends/health", backend2Name) } - if unstableHealthState.Status == "healthy" { + if unstableHealthState.Status == backendHealthStatusHealthy { return fmt.Errorf("unstable backend %q still healthy in /api/backends/health", backend2Name) } unstableStatusHealth, inStatus := statusHealthByName[backend2Name] if !inStatus { return fmt.Errorf("unstable backend %q not found in /status", backend2Name) } - if unstableStatusHealth == "healthy" { + if unstableStatusHealth == backendHealthStatusHealthy { return fmt.Errorf("unstable backend %q still healthy in /status (issue #4103 regression)", backend2Name) } diff --git a/test/integration/vmcp/helpers/backend.go b/test/integration/vmcp/helpers/backend.go index c5594b625b..d991809ed7 100644 --- a/test/integration/vmcp/helpers/backend.go +++ b/test/integration/vmcp/helpers/backend.go @@ -157,6 +157,7 @@ type backendServerConfig struct { withPrompts bool captureHeaders bool httpContextFunc server.HTTPContextFunc + httpMiddleware func(http.Handler) http.Handler } // WithBackendName sets the backend server name. @@ -180,6 +181,31 @@ func WithCaptureHeaders() BackendServerOption { } } +// WithHTTPMiddleware wraps the backend's HTTP handler with the given middleware. +// The middleware runs outside the MCP streamable-HTTP handler, so it intercepts +// requests before they reach the MCP layer (including before any header-capture +// configured via WithCaptureHeaders). This allows tests to inject custom HTTP +// behaviour such as returning error status codes for the first N requests +// (simulating transient auth failures). +// +// Example: return 401 for the first request, then pass through: +// +// var count atomic.Int32 +// helpers.WithHTTPMiddleware(func(next http.Handler) http.Handler { +// return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// if count.Add(1) == 1 { +// w.WriteHeader(http.StatusUnauthorized) +// return +// } +// next.ServeHTTP(w, r) +// }) +// }) +func WithHTTPMiddleware(middleware func(http.Handler) http.Handler) BackendServerOption { + return func(c *backendServerConfig) { + c.httpMiddleware = middleware + } +} + // CreateBackendServer creates an MCP backend server using the mark3labs/mcp-go SDK. // It returns an *httptest.Server ready to accept streamable-HTTP connections. // @@ -304,8 +330,14 @@ func CreateBackendServer(tb testing.TB, tools []BackendTool, opts ...BackendServ streamableOpts..., ) + // Wrap with optional HTTP middleware (e.g., to inject transient HTTP errors in tests) + var handler http.Handler = streamableServer + if config.httpMiddleware != nil { + handler = config.httpMiddleware(handler) + } + // Start HTTP test server - httpServer := httptest.NewServer(streamableServer) + httpServer := httptest.NewServer(handler) tb.Logf("Created MCP backend server %q (v%s) at %s%s", config.serverName,