Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions pkg/gateway/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,32 @@ func (cp *clientPool) ReleaseClientsForSession(session *mcp.ServerSession) {
}
}

// InvalidateKeptClient closes and removes the cached client for the given server/session key.
func (cp *clientPool) InvalidateKeptClient(serverConfig *catalog.ServerConfig, config *clientConfig) {
if config == nil || config.serverSession == nil {
return
}

key := clientKey{serverName: serverConfig.Name, session: config.serverSession}

cp.clientLock.Lock()
defer cp.clientLock.Unlock()

kc, exists := cp.keptClients[key]
if !exists {
return
}

log.Log(fmt.Sprintf("ClientPool: Invalidating kept client for server: %s", serverConfig.Name))
if kc.Getter.started.Load() {
client, err := kc.Getter.GetClient(context.Background())
if err == nil {
client.Session().Close()
}
}
delete(cp.keptClients, key)
}

// InvalidateOAuthClients closes and removes all OAuth client connections for the specified provider
// This allows clients to reconnect with updated/refreshed tokens
func (cp *clientPool) InvalidateOAuthClients(provider string) {
Expand Down
42 changes: 42 additions & 0 deletions pkg/gateway/clientpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,3 +912,45 @@ func TestStdioClientInitialization(t *testing.T) {

t.Logf("Successfully initialized stdio client and retrieved %d tools", len(tools.Tools))
}

func TestInvalidateKeptClient_RemovesMatchingSession(t *testing.T) {
session := &mcp.ServerSession{}
cp := &clientPool{
keptClients: make(map[clientKey]keptClient),
}

getter := &clientGetter{}
getter.once.Do(func() {})
getter.err = fmt.Errorf("mock: no real client")

serverConfig := &catalog.ServerConfig{
Name: "remote-svc",
Spec: catalog.Server{
Type: "remote",
Remote: catalog.Remote{URL: "https://mcp.example.com/mcp"},
},
}
key := clientKey{serverName: "remote-svc", session: session}
cp.keptClients[key] = keptClient{
Name: "remote-svc",
Getter: getter,
Config: serverConfig,
}

cp.InvalidateKeptClient(serverConfig, &clientConfig{serverSession: session})

assert.Empty(t, cp.keptClients)
}

func TestInvalidateKeptClient_SkipsNilSession(t *testing.T) {
cp := &clientPool{
keptClients: map[clientKey]keptClient{
{serverName: "remote-svc"}: {},
},
}

serverConfig := &catalog.ServerConfig{Name: "remote-svc"}
cp.InvalidateKeptClient(serverConfig, nil)

assert.Len(t, cp.keptClients, 1)
}
6 changes: 3 additions & 3 deletions pkg/gateway/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
}
}

func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _ *mcp.ToolAnnotations, originalToolName string) mcp.ToolHandler {
func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, annotations *mcp.ToolAnnotations, originalToolName string) mcp.ToolHandler {
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Look up server configuration
serverConfig, _, ok := g.configuration.Find(serverName)
Expand Down Expand Up @@ -147,8 +147,8 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _
Arguments: args,
}

// Execute the tool call
result, err := client.Session().CallTool(ctx, params)
// Execute the tool call, recovering from stale remote sessions when safe.
result, err := g.callRemoteTool(ctx, serverConfig, server, annotations, originalToolName, params, req.Session, client)

// Record duration
duration := time.Since(startTime).Milliseconds()
Expand Down
89 changes: 89 additions & 0 deletions pkg/gateway/stale_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package gateway

import (
"context"
"fmt"

"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/docker/mcp-gateway/pkg/catalog"
mcpclient "github.com/docker/mcp-gateway/pkg/mcp"
)

// isStaleEmptySuccess reports whether a tool result is a structurally empty
// success response that typically indicates a stale remote session.
func isStaleEmptySuccess(result *mcp.CallToolResult) bool {
if result == nil || result.IsError {
return false
}
if len(result.Content) > 0 {
return false
}
return result.StructuredContent == nil
}

// isSafeToRetryTool reports whether an empty-success recovery retry is safe
// based on MCP tool annotation hints.
func isSafeToRetryTool(annotations *mcp.ToolAnnotations) bool {
if annotations == nil {
return false
}
if annotations.ReadOnlyHint {
return true
}
return annotations.IdempotentHint
}

func (g *Gateway) callRemoteTool(
ctx context.Context,
serverConfig *catalog.ServerConfig,
server *mcp.Server,
annotations *mcp.ToolAnnotations,
originalToolName string,
params *mcp.CallToolParams,
session *mcp.ServerSession,
client mcpclient.Client,
) (*mcp.CallToolResult, error) {
result, err := client.Session().CallTool(ctx, params)
if err != nil {
return nil, err
}

if !serverConfig.IsRemote() || !isStaleEmptySuccess(result) {
return result, nil
}

clientConfig := getClientConfig(session, server)
g.clientPool.InvalidateKeptClient(serverConfig, clientConfig)

if !isSafeToRetryTool(annotations) {
return nil, fmt.Errorf(
"remote tool %q on server %q returned an empty success response (stale session)",
originalToolName,
serverConfig.Name,
)
}

retryClient, retryErr := g.clientPool.AcquireClient(ctx, serverConfig, clientConfig)
if retryErr != nil {
return nil, fmt.Errorf(
"remote tool %q returned empty result and failed to refresh session: %w",
originalToolName,
retryErr,
)
}
defer g.clientPool.ReleaseClient(retryClient)

result, err = retryClient.Session().CallTool(ctx, params)
if err != nil {
return nil, err
}
if isStaleEmptySuccess(result) {
return nil, fmt.Errorf(
"remote tool %q on server %q returned an empty success response after session refresh",
originalToolName,
serverConfig.Name,
)
}
return result, nil
}
60 changes: 60 additions & 0 deletions pkg/gateway/stale_response_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package gateway

import (
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/assert"
)

func TestIsStaleEmptySuccess(t *testing.T) {
t.Run("nil result", func(t *testing.T) {
assert.False(t, isStaleEmptySuccess(nil))
})

t.Run("error result", func(t *testing.T) {
assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{
IsError: true,
}))
})

t.Run("empty content success", func(t *testing.T) {
assert.True(t, isStaleEmptySuccess(&mcp.CallToolResult{
Content: []mcp.Content{},
}))
})

t.Run("nil content success", func(t *testing.T) {
assert.True(t, isStaleEmptySuccess(&mcp.CallToolResult{}))
})

t.Run("non-empty content", func(t *testing.T) {
assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{
Content: []mcp.Content{&mcp.TextContent{Text: `{"issues":[]}`}},
}))
})

t.Run("structured content only", func(t *testing.T) {
assert.False(t, isStaleEmptySuccess(&mcp.CallToolResult{
StructuredContent: map[string]any{"issues": []any{}},
}))
})
}

func TestIsSafeToRetryTool(t *testing.T) {
t.Run("nil annotations", func(t *testing.T) {
assert.False(t, isSafeToRetryTool(nil))
})

t.Run("read only", func(t *testing.T) {
assert.True(t, isSafeToRetryTool(&mcp.ToolAnnotations{ReadOnlyHint: true}))
})

t.Run("idempotent write", func(t *testing.T) {
assert.True(t, isSafeToRetryTool(&mcp.ToolAnnotations{IdempotentHint: true}))
})

t.Run("no hints", func(t *testing.T) {
assert.False(t, isSafeToRetryTool(&mcp.ToolAnnotations{}))
})
}