Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions pkg/gateway/mcpadd.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,12 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler {

// If secrets or config are missing, handle based on client type
if len(missingSecrets) > 0 || len(missingConfig) > 0 {
// Safely determine client name (InitializeParams may be nil for some transports)
// Safely determine client name (session or InitializeParams may be nil)
clientName := ""
if init := req.Session.InitializeParams(); init != nil && init.ClientInfo != nil {
clientName = init.ClientInfo.Name
if req.Session != nil {
if init := req.Session.InitializeParams(); init != nil && init.ClientInfo != nil {
clientName = init.ClientInfo.Name
}
}

if clientName == "nanobot" && len(missingSecrets) > 0 {
Expand Down Expand Up @@ -317,12 +319,15 @@ func addServerHandler(g *Gateway, clientConfig *clientConfig) mcp.ToolHandler {
if g.McpOAuthDcrEnabled &&
serverConfig != nil &&
serverConfig.IsRemote() {

init := req.Session.InitializeParams()
if init != nil &&
init.Capabilities != nil &&
init.Capabilities.Elicitation != nil {

shouldHandleOAuth := req.Session == nil
if !shouldHandleOAuth {
if init := req.Session.InitializeParams(); init != nil &&
init.Capabilities != nil &&
init.Capabilities.Elicitation != nil {
shouldHandleOAuth = true
}
}
if shouldHandleOAuth {
authorized, oauthText := g.getRemoteOAuthServerStatus(
ctx,
serverName,
Expand Down Expand Up @@ -515,10 +520,10 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str
}

// Proceed with elicitation only if the client supports it
init := req.Session.InitializeParams()
if init != nil &&
init.Capabilities != nil &&
init.Capabilities.Elicitation != nil {
if req.Session != nil {
if init := req.Session.InitializeParams(); init != nil &&
init.Capabilities != nil &&
init.Capabilities.Elicitation != nil {
// Elicit a response from the client asking whether to open a browser for authorization
elicitResult, err := req.Session.Elicit(ctx, &mcp.ElicitParams{
Message: fmt.Sprintf("Would you like to open a browser to authorize the '%s' server?", serverName),
Expand Down Expand Up @@ -561,6 +566,7 @@ func (g *Gateway) getRemoteOAuthServerStatus(ctx context.Context, serverName str
}

return true, fmt.Sprintf("Successfully added server '%s'. Authorization completed.", serverName)
}
}

// Check if user is already authorized by checking if token exists (only if provider exists)
Expand Down
103 changes: 103 additions & 0 deletions pkg/gateway/mcpadd_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package gateway

import (
"context"
"encoding/json"
"strings"
"testing"

"github.com/modelcontextprotocol/go-sdk/mcp"

"github.com/docker/mcp-gateway/pkg/catalog"
"github.com/docker/mcp-gateway/pkg/oauth"
)

func TestAddServerHandlerNilSessionMissingConfig(t *testing.T) {
t.Parallel()

g := &Gateway{
configuration: Configuration{
serverNames: []string{},
servers: map[string]catalog.Server{
"webhook-mcp": {
Name: "webhook-mcp",
Image: "example/webhook-mcp:latest",
Config: []any{
map[string]any{
"name": "endpoint",
"type": "string",
},
},
},
},
},
}

handler := addServerHandler(g, nil)
args, err := json.Marshal(map[string]any{
"name": "webhook-mcp",
})
if err != nil {
t.Fatalf("marshal arguments: %v", err)
}

req := &mcp.CallToolRequest{
Session: nil,
Params: &mcp.CallToolParamsRaw{
Arguments: args,
},
}

result, err := handler(context.Background(), req)
if err != nil {
t.Fatalf("handler returned error: %v", err)
}
if result == nil || len(result.Content) == 0 {
t.Fatal("expected tool result content")
}

text, ok := result.Content[0].(*mcp.TextContent)
if !ok {
t.Fatalf("expected text content, got %T", result.Content[0])
}
if !strings.Contains(text.Text, "Missing required") {
t.Fatalf("expected missing requirements error, got: %s", text.Text)
}
}

func TestGetRemoteOAuthServerStatusNilSession(t *testing.T) {
t.Parallel()

g := &Gateway{
Options: Options{McpOAuthDcrEnabled: true},
configuration: Configuration{
servers: map[string]catalog.Server{
"remote-oauth": {
Name: "remote-oauth",
Type: "remote",
Remote: catalog.Remote{
URL: "https://example.com/mcp",
},
},
},
},
oauthProviders: map[string]*oauth.Provider{
"remote-oauth": {},
},
}

req := &mcp.CallToolRequest{Session: nil}
authorized, message := g.getRemoteOAuthServerStatus(
context.Background(),
"remote-oauth",
req,
false,
)

if authorized {
t.Fatalf("expected unauthorized flow without session, got authorized with message: %q", message)
}
if !strings.Contains(message, "authorize") {
t.Fatalf("expected authorization instructions, got: %s", message)
}
}