diff --git a/pkg/gateway/clientpool.go b/pkg/gateway/clientpool.go index 3f04f40f..bad9d5c6 100644 --- a/pkg/gateway/clientpool.go +++ b/pkg/gateway/clientpool.go @@ -200,6 +200,32 @@ func (cp *clientPool) ReleaseClientsForSession(session *mcp.ServerSession) { } } +// InvalidateKeptClient closes and removes the cached client for the given server/session key. +func (cp *clientPool) InvalidateKeptClient(serverConfig *catalog.ServerConfig, config *clientConfig) { + if config == nil || config.serverSession == nil { + return + } + + key := clientKey{serverName: serverConfig.Name, session: config.serverSession} + + cp.clientLock.Lock() + defer cp.clientLock.Unlock() + + kc, exists := cp.keptClients[key] + if !exists { + return + } + + log.Log(fmt.Sprintf("ClientPool: Invalidating kept client for server: %s", serverConfig.Name)) + if kc.Getter.started.Load() { + client, err := kc.Getter.GetClient(context.Background()) + if err == nil { + client.Session().Close() + } + } + delete(cp.keptClients, key) +} + // InvalidateOAuthClients closes and removes all OAuth client connections for the specified provider // This allows clients to reconnect with updated/refreshed tokens func (cp *clientPool) InvalidateOAuthClients(provider string) { diff --git a/pkg/gateway/clientpool_test.go b/pkg/gateway/clientpool_test.go index 3b984a5e..e15ad2ac 100644 --- a/pkg/gateway/clientpool_test.go +++ b/pkg/gateway/clientpool_test.go @@ -912,3 +912,45 @@ func TestStdioClientInitialization(t *testing.T) { t.Logf("Successfully initialized stdio client and retrieved %d tools", len(tools.Tools)) } + +func TestInvalidateKeptClient_RemovesMatchingSession(t *testing.T) { + session := &mcp.ServerSession{} + cp := &clientPool{ + keptClients: make(map[clientKey]keptClient), + } + + getter := &clientGetter{} + getter.once.Do(func() {}) + getter.err = fmt.Errorf("mock: no real client") + + serverConfig := &catalog.ServerConfig{ + Name: "remote-svc", + Spec: catalog.Server{ + Type: "remote", + Remote: catalog.Remote{URL: "https://mcp.example.com/mcp"}, + }, + } + key := clientKey{serverName: "remote-svc", session: session} + cp.keptClients[key] = keptClient{ + Name: "remote-svc", + Getter: getter, + Config: serverConfig, + } + + cp.InvalidateKeptClient(serverConfig, &clientConfig{serverSession: session}) + + assert.Empty(t, cp.keptClients) +} + +func TestInvalidateKeptClient_SkipsNilSession(t *testing.T) { + cp := &clientPool{ + keptClients: map[clientKey]keptClient{ + {serverName: "remote-svc"}: {}, + }, + } + + serverConfig := &catalog.ServerConfig{Name: "remote-svc"} + cp.InvalidateKeptClient(serverConfig, nil) + + assert.Len(t, cp.keptClients, 1) +} diff --git a/pkg/gateway/handlers.go b/pkg/gateway/handlers.go index 1675a20b..48582359 100644 --- a/pkg/gateway/handlers.go +++ b/pkg/gateway/handlers.go @@ -59,7 +59,7 @@ func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler { } } -func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _ *mcp.ToolAnnotations, originalToolName string) mcp.ToolHandler { +func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, annotations *mcp.ToolAnnotations, originalToolName string) mcp.ToolHandler { return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Look up server configuration serverConfig, _, ok := g.configuration.Find(serverName) @@ -147,8 +147,8 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _ Arguments: args, } - // Execute the tool call - result, err := client.Session().CallTool(ctx, params) + // Execute the tool call, recovering from stale remote sessions when safe. + result, err := g.callRemoteTool(ctx, serverConfig, server, annotations, originalToolName, params, req.Session, client) // Record duration duration := time.Since(startTime).Milliseconds() diff --git a/pkg/gateway/stale_response.go b/pkg/gateway/stale_response.go new file mode 100644 index 00000000..6328d108 --- /dev/null +++ b/pkg/gateway/stale_response.go @@ -0,0 +1,89 @@ +package gateway + +import ( + "context" + "fmt" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/docker/mcp-gateway/pkg/catalog" + mcpclient "github.com/docker/mcp-gateway/pkg/mcp" +) + +// isStaleEmptySuccess reports whether a tool result is a structurally empty +// success response that typically indicates a stale remote session. +func isStaleEmptySuccess(result *mcp.CallToolResult) bool { + if result == nil || result.IsError { + return false + } + if len(result.Content) > 0 { + return false + } + return result.StructuredContent == nil +} + +// isSafeToRetryTool reports whether an empty-success recovery retry is safe +// based on MCP tool annotation hints. +func isSafeToRetryTool(annotations *mcp.ToolAnnotations) bool { + if annotations == nil { + return false + } + if annotations.ReadOnlyHint { + return true + } + return annotations.IdempotentHint +} + +func (g *Gateway) callRemoteTool( + ctx context.Context, + serverConfig *catalog.ServerConfig, + server *mcp.Server, + annotations *mcp.ToolAnnotations, + originalToolName string, + params *mcp.CallToolParams, + session *mcp.ServerSession, + client mcpclient.Client, +) (*mcp.CallToolResult, error) { + result, err := client.Session().CallTool(ctx, params) + if err != nil { + return nil, err + } + + if !serverConfig.IsRemote() || !isStaleEmptySuccess(result) { + return result, nil + } + + clientConfig := getClientConfig(session, server) + g.clientPool.InvalidateKeptClient(serverConfig, clientConfig) + + if !isSafeToRetryTool(annotations) { + return nil, fmt.Errorf( + "remote tool %q on server %q returned an empty success response (stale session)", + originalToolName, + serverConfig.Name, + ) + } + + retryClient, retryErr := g.clientPool.AcquireClient(ctx, serverConfig, clientConfig) + if retryErr != nil { + return nil, fmt.Errorf( + "remote tool %q returned empty result and failed to refresh session: %w", + originalToolName, + retryErr, + ) + } + defer g.clientPool.ReleaseClient(retryClient) + + result, err = retryClient.Session().CallTool(ctx, params) + if err != nil { + return nil, err + } + if isStaleEmptySuccess(result) { + return nil, fmt.Errorf( + "remote tool %q on server %q returned an empty success response after session refresh", + originalToolName, + serverConfig.Name, + ) + } + return result, nil +} diff --git a/pkg/gateway/stale_response_test.go b/pkg/gateway/stale_response_test.go new file mode 100644 index 00000000..c3751434 --- /dev/null +++ b/pkg/gateway/stale_response_test.go @@ -0,0 +1,60 @@ +package gateway + +import ( + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" +) + +func TestIsStaleEmptySuccess(t *testing.T) { + t.Run("nil result", func(t *testing.T) { + assert.False(t, isStaleEmptySuccess(nil)) + }) + + t.Run("error result", func(t *testing.T) { + assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{ + IsError: true, + })) + }) + + t.Run("empty content success", func(t *testing.T) { + assert.True(t, isStaleEmptySuccess(&mcp.CallToolResult{ + Content: []mcp.Content{}, + })) + }) + + t.Run("nil content success", func(t *testing.T) { + assert.True(t, isStaleEmptySuccess(&mcp.CallToolResult{})) + }) + + t.Run("non-empty content", func(t *testing.T) { + assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: `{"issues":[]}`}}, + })) + }) + + t.Run("structured content only", func(t *testing.T) { + assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{ + StructuredContent: map[string]any{"issues": []any{}}, + })) + }) +} + +func TestIsSafeToRetryTool(t *testing.T) { + t.Run("nil annotations", func(t *testing.T) { + assert.False(t, isSafeToRetryTool(nil)) + }) + + t.Run("read only", func(t *testing.T) { + assert.True(t, isSafeToRetryTool(&mcp.ToolAnnotations{ReadOnlyHint: true})) + }) + + t.Run("idempotent write", func(t *testing.T) { + assert.True(t, isSafeToRetryTool(&mcp.ToolAnnotations{IdempotentHint: true})) + }) + + t.Run("no hints", func(t *testing.T) { + assert.False(t, isSafeToRetryTool(&mcp.ToolAnnotations{})) + }) +}