diff --git a/pkg/transport/middleware/write_timeout.go b/pkg/transport/middleware/write_timeout.go
new file mode 100644
index 0000000000..85a46fb251
--- /dev/null
+++ b/pkg/transport/middleware/write_timeout.go
@@ -0,0 +1,53 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package middleware
+
+import (
+	"log/slog"
+	"net/http"
+	"strings"
+	"time"
+)
+
+// WriteTimeout returns middleware that clears the write deadline for qualifying
+// SSE connections so http.Server.WriteTimeout does not kill long-lived streams
+// (golang/go#16100). A request qualifies when it is a GET whose URL path equals
+// endpointPath and whose Accept header mentions text/event-stream. All other
+// requests are left untouched.
+//
+// Clearing the deadline is best-effort: if SetWriteDeadline returns an error, a
+// warning is logged and the request is still passed to the next handler.
+func WriteTimeout(endpointPath string) func(http.Handler) http.Handler {
+	return func(next http.Handler) http.Handler {
+		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if r.Method == http.MethodGet &&
+				r.URL.Path == endpointPath &&
+				acceptsEventStream(r.Header) {
+				rc := http.NewResponseController(w)
+				if err := rc.SetWriteDeadline(time.Time{}); err != nil {
+					slog.Warn("failed to clear write deadline for SSE connection; stream may be killed by server WriteTimeout",
+						"error", err,
+						"method", r.Method,
+						"path", r.URL.Path,
+						"remote", r.RemoteAddr,
+					)
+				}
+			}
+			next.ServeHTTP(w, r)
+		})
+	}
+}
+
+// acceptsEventStream reports whether any Accept header field line mentions the
+// text/event-stream media type. All field lines are inspected (Header.Get only
+// returns the first) and the match is case-insensitive, because media type
+// names are case-insensitive.
+func acceptsEventStream(h http.Header) bool {
+	for _, v := range h.Values("Accept") {
+		if strings.Contains(strings.ToLower(v), "text/event-stream") {
+			return true
+		}
+	}
+	return false
+}
diff --git a/pkg/transport/middleware/write_timeout_test.go b/pkg/transport/middleware/write_timeout_test.go
new file mode 100644
index 0000000000..7cd9c9857c
--- /dev/null
+++ b/pkg/transport/middleware/write_timeout_test.go
@@ -0,0 +1,256 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package middleware_test
+
+import (
+	"bufio"
+	"fmt"
+	"net/http"
+	"net/http/httptest"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/stacklok/toolhive/pkg/transport/middleware"
+)
+
+const testEndpointPath = "/mcp"
+
+// deadlineTrackingResponseWriter is an httptest.ResponseRecorder that also offers
+// SetWriteDeadline, the method http.ResponseController looks for. It remembers
+// whether the method was invoked and which deadline it was given.
+type deadlineTrackingResponseWriter struct {
+	*httptest.ResponseRecorder
+	deadlineSet bool
+	deadline    time.Time
+}
+
+// SetWriteDeadline records the call and the requested deadline; it never fails.
+func (rw *deadlineTrackingResponseWriter) SetWriteDeadline(deadline time.Time) error {
+	rw.deadlineSet, rw.deadline = true, deadline
+	return nil
+}
+
+// newDeadlineTracker returns a tracker backed by a fresh ResponseRecorder.
+func newDeadlineTracker() *deadlineTrackingResponseWriter {
+	return &deadlineTrackingResponseWriter{ResponseRecorder: httptest.NewRecorder()}
+}
+
+// noopHandler answers every request with 200 OK and no body.
+var noopHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+	rw.WriteHeader(http.StatusOK)
+})
+
+// mw wraps next in the WriteTimeout middleware configured for testEndpointPath.
+func mw(next http.Handler) http.Handler {
+	return middleware.WriteTimeout(testEndpointPath)(next)
+}
+
+// TestWriteTimeout_SSERequestClearsDeadline verifies that a qualifying SSE request (GET,
+// Accept: text/event-stream, matching path) gets its write deadline cleared (set to zero).
+func TestWriteTimeout_SSERequestClearsDeadline(t *testing.T) {
+	t.Parallel()
+
+	req := httptest.NewRequest(http.MethodGet, testEndpointPath, nil)
+	req.Header.Set("Accept", "text/event-stream")
+	tracker := newDeadlineTracker()
+
+	mw(noopHandler).ServeHTTP(tracker, req)
+
+	require.True(t, tracker.deadlineSet, "qualifying SSE request must call SetWriteDeadline")
+	assert.True(t, tracker.deadline.IsZero(), "deadline must be zero (no deadline) to override server WriteTimeout")
+	assert.Equal(t, http.StatusOK, tracker.Code)
+}
+
+// TestWriteTimeout_GETWithoutAcceptHeaderLeavesDeadlineUntouched checks that a GET
+// without Accept: text/event-stream is not treated as SSE: the middleware never
+// calls SetWriteDeadline, so http.Server.WriteTimeout stays in effect.
+func TestWriteTimeout_GETWithoutAcceptHeaderLeavesDeadlineUntouched(t *testing.T) {
+	t.Parallel()
+
+	req := httptest.NewRequest(http.MethodGet, testEndpointPath, nil)
+	tracker := newDeadlineTracker()
+
+	mw(noopHandler).ServeHTTP(tracker, req)
+
+	assert.False(t, tracker.deadlineSet, "non-SSE GET must not have its deadline touched; server WriteTimeout remains in effect")
+	assert.Equal(t, http.StatusOK, tracker.Code)
+}
+
+// TestWriteTimeout_GETOnWrongPathLeavesDeadlineUntouched checks that a GET carrying
+// the SSE Accept header but aimed at a non-MCP path (e.g. /health) is not treated
+// as SSE, so the middleware does not touch its write deadline.
+func TestWriteTimeout_GETOnWrongPathLeavesDeadlineUntouched(t *testing.T) {
+	t.Parallel()
+
+	req := httptest.NewRequest(http.MethodGet, "/health", nil)
+	req.Header.Set("Accept", "text/event-stream")
+	tracker := newDeadlineTracker()
+
+	mw(noopHandler).ServeHTTP(tracker, req)
+
+	assert.False(t, tracker.deadlineSet, "GET on non-MCP path must not have its deadline touched; server WriteTimeout remains in effect")
+	assert.Equal(t, http.StatusOK, tracker.Code)
+}
+
+// TestWriteTimeout_POSTLeavesDeadlineUntouched checks that the middleware ignores
+// POST requests; their write deadline comes from http.Server.WriteTimeout.
+func TestWriteTimeout_POSTLeavesDeadlineUntouched(t *testing.T) {
+	t.Parallel()
+
+	req := httptest.NewRequest(http.MethodPost, testEndpointPath, nil)
+	tracker := newDeadlineTracker()
+
+	mw(noopHandler).ServeHTTP(tracker, req)
+
+	assert.False(t, tracker.deadlineSet, "POST deadline is managed by http.Server.WriteTimeout, not the middleware")
+	assert.Equal(t, http.StatusOK, tracker.Code)
+}
+
+// TestWriteTimeout_DELETELeavesDeadlineUntouched checks the same for DELETE.
+func TestWriteTimeout_DELETELeavesDeadlineUntouched(t *testing.T) {
+	t.Parallel()
+
+	req := httptest.NewRequest(http.MethodDelete, testEndpointPath, nil)
+	tracker := newDeadlineTracker()
+
+	mw(noopHandler).ServeHTTP(tracker, req)
+
+	assert.False(t, tracker.deadlineSet, "DELETE deadline is managed by http.Server.WriteTimeout, not the middleware")
+	assert.Equal(t, http.StatusOK, tracker.Code)
+}
+
+// TestWriteTimeout_HandlerIsAlwaysCalled checks that the wrapped handler runs for
+// every request shape, whether or not the middleware touched the deadline.
+func TestWriteTimeout_HandlerIsAlwaysCalled(t *testing.T) {
+	t.Parallel()
+
+	scenarios := []struct {
+		method string
+		path   string
+		accept string
+	}{
+		{method: http.MethodGet, path: testEndpointPath, accept: "text/event-stream"}, // qualifying SSE
+		{method: http.MethodGet, path: testEndpointPath},                              // GET, no Accept
+		{method: http.MethodGet, path: "/health", accept: "text/event-stream"},        // GET, wrong path
+		{method: http.MethodPost, path: testEndpointPath},
+		{method: http.MethodDelete, path: testEndpointPath},
+	}
+
+	for _, sc := range scenarios {
+		t.Run(sc.method+sc.path+sc.accept, func(t *testing.T) {
+			t.Parallel()
+
+			invoked := false
+			inner := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
+				invoked = true
+				rw.WriteHeader(http.StatusOK)
+			})
+
+			tracker := newDeadlineTracker()
+			req := httptest.NewRequest(sc.method, sc.path, nil)
+			if sc.accept != "" {
+				req.Header.Set("Accept", sc.accept)
+			}
+			mw(inner).ServeHTTP(tracker, req)
+
+			assert.True(t, invoked, "inner handler must be called for %s %s", sc.method, sc.path)
+		})
+	}
+}
+
+// TestWriteTimeout_SSEStreamSurvivesTimeout checks, over a real TCP connection
+// served with http.Server.WriteTimeout set, that a qualifying SSE stream is NOT
+// killed once the write timeout has elapsed.
+//
+// It is the end-to-end proof of the fix for the SSE connection drop bug
+// (golang/go#16100): for qualifying SSE requests the middleware clears the
+// per-connection write deadline via http.ResponseController.SetWriteDeadline(time.Time{}),
+// which keeps the stream alive past the server-level WriteTimeout.
+func TestWriteTimeout_SSEStreamSurvivesTimeout(t *testing.T) {
+	t.Parallel()
+
+	const shortTimeout = 100 * time.Millisecond
+	const streamDuration = 3 * shortTimeout
+	const tickInterval = shortTimeout / 5 // interval between SSE events
+
+	sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "text/event-stream")
+		w.Header().Set("Cache-Control", "no-cache")
+		w.WriteHeader(http.StatusOK)
+
+		flusher, ok := w.(http.Flusher)
+		if !assert.True(t, ok, "ResponseWriter must implement http.Flusher") {
+			return // assert, not require: FailNow must only run on the test goroutine
+		}
+
+		ticker := time.NewTicker(tickInterval)
+		defer ticker.Stop()
+		deadline := time.NewTimer(streamDuration)
+		defer deadline.Stop()
+
+		for {
+			select {
+			case <-r.Context().Done():
+				return
+			case <-deadline.C:
+				return
+			case <-ticker.C:
+				fmt.Fprintf(w, "data: ping\n\n")
+				flusher.Flush()
+			}
+		}
+	})
+
+	ts := httptest.NewUnstartedServer(middleware.WriteTimeout(testEndpointPath)(sseHandler))
+	ts.Config.WriteTimeout = shortTimeout
+	ts.Start()
+	t.Cleanup(ts.Close)
+
+	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL+testEndpointPath, nil)
+	require.NoError(t, err)
+	req.Header.Set("Accept", "text/event-stream")
+
+	start := time.Now()
+
+	resp, err := ts.Client().Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Over the full streamDuration we expect ~streamDuration/tickInterval = 15
+	// events. If WriteTimeout fires early (after shortTimeout = 100 ms) at most
+	// shortTimeout/tickInterval = 5 events could arrive before the connection is killed.
+	minEvents := int(shortTimeout/tickInterval) + 1 // must exceed what's possible before WriteTimeout
+
+	scanner := bufio.NewScanner(resp.Body)
+	var events []string
+	for scanner.Scan() {
+		if strings.HasPrefix(scanner.Text(), "data:") {
+			events = append(events, scanner.Text())
+		}
+	}
+	elapsed := time.Since(start)
+
+	// A clean EOF is necessary but not sufficient: if WriteTimeout kills the stream at
+	// shortTimeout the client may still observe a clean close with a handful of events.
+	assert.NoError(t, scanner.Err(), "SSE stream must close cleanly, not with a connection error")
+
+	// Elapsed time proves the stream ran for (at least) its intended lifetime.
+	// If WriteTimeout had fired the handler would have been interrupted at ~100 ms,
+	// far shorter than streamDuration (300 ms).
+	assert.GreaterOrEqual(t, elapsed, streamDuration-50*time.Millisecond,
+		"SSE stream must have lasted at least streamDuration (%v); elapsed %v suggests WriteTimeout fired early",
+		streamDuration, elapsed)
+
+	// Event count provides a second, independent signal: the stream must have
+	// delivered more events than could possibly arrive within shortTimeout.
+	assert.GreaterOrEqual(t, len(events), minEvents,
+		"expected >= %d events (more than possible before WriteTimeout); got %d",
+		minEvents, len(events))
+}
diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go
index 2d267ce4a3..1b33bb61ee 100644
--- a/pkg/vmcp/server/server.go
+++ b/pkg/vmcp/server/server.go
@@ -27,6 +27,7 @@ import (
 	mcpparser "github.com/stacklok/toolhive/pkg/mcp"
 	"github.com/stacklok/toolhive/pkg/recovery"
 	"github.com/stacklok/toolhive/pkg/telemetry"
+	transportmiddleware "github.com/stacklok/toolhive/pkg/transport/middleware"
 	transportsession "github.com/stacklok/toolhive/pkg/transport/session"
 	"github.com/stacklok/toolhive/pkg/vmcp"
 	"github.com/stacklok/toolhive/pkg/vmcp/composer"
@@ -47,7 +48,10 @@ const (
 	// defaultReadTimeout is the maximum duration for reading the entire request, including body.
 	defaultReadTimeout = 30 * time.Second
 
-	// defaultWriteTimeout is the maximum duration before timing out writes of the response.
+	// defaultWriteTimeout is the server-level write deadline set on http.Server.WriteTimeout.
+	// It protects all routes (health, metrics, well-known, etc.) from slow-write clients.
+	// For qualifying SSE (GET) connections, transportmiddleware.WriteTimeout clears this
+	// per-request via http.ResponseController.SetWriteDeadline(time.Time{}) (golang/go#16100).
 	defaultWriteTimeout = 30 * time.Second
 
 	// defaultIdleTimeout is the maximum amount of time to wait for the next request when keep-alive's are enabled.
@@ -557,6 +561,13 @@ func (s *Server) Handler(_ context.Context) (http.Handler, error) {
 	// Apply Accept header validation (rejects GET requests without Accept: text/event-stream)
 	mcpHandler = headerValidatingMiddleware(mcpHandler)
 
+	// Clear the write deadline for qualifying SSE connections (GET +
+	// Accept: text/event-stream + MCP endpoint path) so the server-level
+	// WriteTimeout does not kill long-lived SSE streams (see golang/go#16100).
+ // Non-qualifying requests are left untouched; http.Server.WriteTimeout + // (defaultWriteTimeout) remains in effect for them. + mcpHandler = transportmiddleware.WriteTimeout(s.config.EndpointPath)(mcpHandler) + // Apply recovery middleware as outermost (catches panics from all inner middleware) mcpHandler = recovery.Middleware(mcpHandler) slog.Info("recovery middleware enabled for MCP endpoints") diff --git a/pkg/vmcp/server/session_management_realbackend_integration_test.go b/pkg/vmcp/server/session_management_realbackend_integration_test.go index 0f3bf23894..0a638ca603 100644 --- a/pkg/vmcp/server/session_management_realbackend_integration_test.go +++ b/pkg/vmcp/server/session_management_realbackend_integration_test.go @@ -35,10 +35,10 @@ import ( // startRealMCPBackend is defined in testutil_test.go as a shared test utility. -// newRealTestServer builds a vMCP server with session management and and a -// real SessionFactory. The BackendRegistry mock returns the backend at backendURL -// so that CreateSession() opens a real HTTP connection to the MCP server. -func newRealTestServer(t *testing.T, backendURL string) *httptest.Server { +// newRealTestHandler builds the full vMCP handler backed by the MCP server at +// backendURL. It is the low-level helper used by newRealTestServer and any test +// that needs control over the httptest.Server configuration (e.g. WriteTimeout). +func newRealTestHandler(t *testing.T, backendURL string) http.Handler { t.Helper() ctrl := gomock.NewController(t) @@ -88,8 +88,15 @@ func newRealTestServer(t *testing.T, backendURL string) *httptest.Server { handler, err := srv.Handler(context.Background()) require.NoError(t, err) + return handler +} - ts := httptest.NewServer(handler) +// newRealTestServer builds a vMCP server with session management and a real +// SessionFactory. The BackendRegistry mock returns the backend at backendURL +// so that CreateSession() opens a real HTTP connection to the MCP server. 
+func newRealTestServer(t *testing.T, backendURL string) *httptest.Server {
+	t.Helper()
+	ts := httptest.NewServer(newRealTestHandler(t, backendURL))
 	t.Cleanup(ts.Close)
 	return ts
 }
diff --git a/pkg/vmcp/server/write_timeout_integration_test.go b/pkg/vmcp/server/write_timeout_integration_test.go
new file mode 100644
index 0000000000..d3a4e8b265
--- /dev/null
+++ b/pkg/vmcp/server/write_timeout_integration_test.go
@@ -0,0 +1,112 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package server_test
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestIntegration_SSEGetConnectionSurvivesWriteTimeout verifies that the full
+// vMCP server — with the middleware.WriteTimeout transport middleware wired in —
+// keeps a qualifying SSE GET connection alive past the server-level WriteTimeout.
+//
+// The test uses httptest.NewUnstartedServer so it can set a very short
+// WriteTimeout before starting the server. It then opens a GET /mcp request
+// with Accept: text/event-stream and reads from the body inside a context whose
+// deadline is 3× the WriteTimeout. Two outcomes are possible:
+//
+//   - context.DeadlineExceeded: the read was still pending when our observation
+//     window ended — the connection was NOT killed. This is the expected result.
+//   - io.EOF or connection error: the server closed the connection early — the
+//     WriteTimeout fired on the SSE stream. This is the failure case.
+func TestIntegration_SSEGetConnectionSurvivesWriteTimeout(t *testing.T) {
+	t.Parallel()
+
+	const shortTimeout = 200 * time.Millisecond
+
+	backendURL := startRealMCPBackend(t)
+
+	// Build the handler separately so we can wrap it in a server with a custom WriteTimeout.
+	handler := newRealTestHandler(t, backendURL)
+
+	ts := httptest.NewUnstartedServer(handler)
+	ts.Config.WriteTimeout = shortTimeout
+	ts.Start()
+	t.Cleanup(ts.Close)
+
+	// Initialize an MCP session so the server assigns us a valid Mcp-Session-Id.
+	// The initialize POST completes well within the server WriteTimeout.
+	client := NewMCPTestClient(t, ts.URL)
+	sessionID := client.InitializeSession()
+
+	// Open a qualifying SSE GET stream. The observation context lives 3× longer
+	// than the WriteTimeout; if the middleware is absent (or broken) the server
+	// will kill the TCP connection after ~shortTimeout and the read below will
+	// return io.EOF instead of context.DeadlineExceeded.
+	sseCtx, sseCancel := context.WithTimeout(t.Context(), 3*shortTimeout)
+	defer sseCancel()
+
+	req, err := http.NewRequestWithContext(sseCtx, http.MethodGet, ts.URL+"/mcp", nil)
+	require.NoError(t, err)
+	req.Header.Set("Accept", "text/event-stream")
+	req.Header.Set("Mcp-Session-Id", sessionID)
+
+	resp, err := ts.Client().Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+	require.Equal(t, http.StatusOK, resp.StatusCode)
+
+	// Loop until the observation window closes. io.EOF or a connection error
+	// before the deadline means WriteTimeout killed the stream (test failure);
+	// context expiry with the connection intact is the expected outcome.
+	buf := make([]byte, 64)
+	for {
+		_, readErr := resp.Body.Read(buf)
+		switch {
+		case readErr == nil:
+			continue // data received; connection still alive, keep reading
+		case errors.Is(readErr, context.DeadlineExceeded), errors.Is(readErr, context.Canceled):
+			return // observation window expired with connection intact — test passes
+		case errors.Is(readErr, io.EOF), errors.Is(readErr, io.ErrUnexpectedEOF):
+			assert.Fail(t, "SSE GET connection was closed by the server before the observation window expired; WriteTimeout may have fired", "error: %v", readErr)
+			return
+		default:
+			// Any other error (e.g. connection reset) is also a failure.
+			assert.Fail(t, "unexpected error reading SSE stream", "error: %v", readErr)
+			return
+		}
+	}
+}
+
+// TestIntegration_NonSSEGetRejectedWithNotAcceptable verifies that a GET request
+// without Accept: text/event-stream is rejected by the vMCP server with 406.
+// This confirms that headerValidatingMiddleware fires before the SSE stream is
+// opened, and that the write-timeout middleware does not interfere with the
+// rejection path.
+func TestIntegration_NonSSEGetRejectedWithNotAcceptable(t *testing.T) {
+	t.Parallel()
+
+	backendURL := startRealMCPBackend(t)
+	ts := newRealTestServer(t, backendURL)
+
+	req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, ts.URL+"/mcp", nil)
+	require.NoError(t, err)
+	// No Accept header — not a qualifying SSE request.
+
+	resp, err := ts.Client().Do(req)
+	require.NoError(t, err)
+	defer resp.Body.Close()
+
+	assert.Equal(t, http.StatusNotAcceptable, resp.StatusCode,
+		"GET without Accept: text/event-stream must be rejected with 406")
+}