From 40996992dceb526c54520bd17a787157cde612f1 Mon Sep 17 00:00:00 2001 From: syf2211 Date: Sat, 27 Jun 2026 08:37:57 +0000 Subject: [PATCH] fix(gateway): guard nil session in mcp-add handler Fixes #442 Add nil checks before accessing req.Session in addServerHandler and getRemoteOAuthServerStatus so mcp-add returns structured errors instead of panicking when session info is unavailable. --- pkg/gateway/mcpadd.go | 32 +++++++----- pkg/gateway/mcpadd_test.go | 103 +++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 13 deletions(-) create mode 100644 pkg/gateway/mcpadd_test.go diff --git a/pkg/gateway/mcpadd.go b/pkg/gateway/mcpadd.go index 76c75896..98bab2a4 100644 --- a/pkg/gateway/mcpadd.go +++ b/pkg/gateway/mcpadd.go @@ -175,10 +175,12 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler { // If secrets or config are missing, handle based on client type if len(missingSecrets) > 0 || len(missingConfig) > 0 { - // Safely determine client name (InitializeParams may be nil for some transports) + // Safely determine client name (session or InitializeParams may be nil) clientName := "" - if init := req.Session.InitializeParams(); init != nil && init.ClientInfo != nil { - clientName = init.ClientInfo.Name + if req.Session != nil { + if init := req.Session.InitializeParams(); init != nil && init.ClientInfo != nil { + clientName = init.ClientInfo.Name + } } if clientName == "nanobot" && len(missingSecrets) > 0 { @@ -317,12 +319,15 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler { if g.McpOAuthDcrEnabled && serverConfig != nil && serverConfig.IsRemote() { - - init := req.Session.InitializeParams() - if init != nil && - init.Capabilities != nil && - init.Capabilities.Elicitation != nil { - + shouldHandleOAuth := req.Session == nil + if !shouldHandleOAuth { + if init := req.Session.InitializeParams(); init != nil && + init.Capabilities != nil && + init.Capabilities.Elicitation != nil { + shouldHandleOAuth = true + } + } + if shouldHandleOAuth { authorized, oauthText := g.getRemoteOAuthServerStatus( ctx, serverName, @@ -515,10 +520,10 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str } // Proceed with elicitation only if the client supports it - init := req.Session.InitializeParams() - if init != nil && - init.Capabilities != nil && - init.Capabilities.Elicitation != nil { + if req.Session != nil { + if init := req.Session.InitializeParams(); init != nil && + init.Capabilities != nil && + init.Capabilities.Elicitation != nil { // Elicit a response from the client asking whether to open a browser for authorization elicitResult, err := req.Session.Elicit(ctx, &mcp.ElicitParams{ Message: fmt.Sprintf("Would you like to open a browser to authorize the '%s' server?", serverName), @@ -561,6 +566,7 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str } return true, fmt.Sprintf("Successfully added server '%s'. Authorization completed.", serverName) + } } // Check if user is already authorized by checking if token exists (only if provider exists) diff --git a/pkg/gateway/mcpadd_test.go b/pkg/gateway/mcpadd_test.go new file mode 100644 index 00000000..47b53bd8 --- /dev/null +++ b/pkg/gateway/mcpadd_test.go @@ -0,0 +1,103 @@ +package gateway + +import ( + "context" + "encoding/json" + "strings" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/docker/mcp-gateway/pkg/catalog" + "github.com/docker/mcp-gateway/pkg/oauth" +) + +func TestAddServerHandlerNilSessionMissingConfig(t *testing.T) { + t.Parallel() + + g := &Gateway{ + configuration: Configuration{ + serverNames: []string{}, + servers: map[string]catalog.Server{ + "webhook-mcp": { + Name: "webhook-mcp", + Image: "example/webhook-mcp:latest", + Config: []any{ + map[string]any{ + "name": "endpoint", + "type": "string", + }, + }, + }, + }, + }, + } + + handler := addServerHandler(g, nil) + args, err := json.Marshal(map[string]any{ + "name": "webhook-mcp", + }) + if err != nil { + t.Fatalf("marshal arguments: %v", err) + } + + req := &mcp.CallToolRequest{ + Session: nil, + Params: &mcp.CallToolParamsRaw{ + Arguments: args, + }, + } + + result, err := handler(context.Background(), req) + if err != nil { + t.Fatalf("handler returned error: %v", err) + } + if result == nil || len(result.Content) == 0 { + t.Fatal("expected tool result content") + } + + text, ok := result.Content[0].(*mcp.TextContent) + if !ok { + t.Fatalf("expected text content, got %T", result.Content[0]) + } + if !strings.Contains(text.Text, "Missing required") { + t.Fatalf("expected missing requirements error, got: %s", text.Text) + } +} + +func TestGetRemoteOAuthServerStatusNilSession(t *testing.T) { + t.Parallel() + + g := &Gateway{ + Options: Options{McpOAuthDcrEnabled: true}, + configuration: Configuration{ + servers: map[string]catalog.Server{ + "remote-oauth": { + Name: "remote-oauth", + Type: "remote", + Remote: catalog.Remote{ + URL: "https://example.com/mcp", + }, + }, + }, + }, + oauthProviders: map[string]*oauth.Provider{ + "remote-oauth": {}, + }, + } + + req := &mcp.CallToolRequest{Session: nil} + authorized, message := g.getRemoteOAuthServerStatus( + context.Background(), + "remote-oauth", + req, + false, + ) + + if authorized { + t.Fatalf("expected unauthorized flow without session, got authorized with message: %q", message) + } + if !strings.Contains(message, "authorize") { + t.Fatalf("expected authorization instructions, got: %s", message) + } +}