diff --git a/cmd/docker-mcp/oauth/ls.go b/cmd/docker-mcp/oauth/ls.go index db468cc2f..2fcbcc2e2 100644 --- a/cmd/docker-mcp/oauth/ls.go +++ b/cmd/docker-mcp/oauth/ls.go @@ -7,9 +7,18 @@ import ( "github.com/docker/mcp-gateway/cmd/docker-mcp/secret-management/formatting" "github.com/docker/mcp-gateway/pkg/desktop" + pkgoauth "github.com/docker/mcp-gateway/pkg/oauth" ) func Ls(ctx context.Context, outputJSON bool) error { + if pkgoauth.IsCEMode() { + return lsCEMode(ctx, outputJSON) + } + return lsDesktopMode(ctx, outputJSON) +} + +// lsDesktopMode lists OAuth apps via Docker Desktop (existing behavior) +func lsDesktopMode(ctx context.Context, outputJSON bool) error { client := desktop.NewAuthClient() // Get OAuth apps from Docker Desktop (includes both built-in and DCR providers) @@ -41,3 +50,52 @@ func Ls(ctx context.Context, outputJSON bool) error { formatting.PrettyPrintTable(rows, []int{80, 120}) return nil } + +// lsCEMode lists OAuth apps in standalone CE mode using local credential storage +func lsCEMode(_ context.Context, outputJSON bool) error { + credHelper := pkgoauth.NewReadWriteCredentialHelper() + manager := pkgoauth.NewManager(credHelper) + + clients, err := manager.ListDCRClients() + if err != nil { + return fmt.Errorf("failed to list OAuth apps: %w", err) + } + + type ceApp struct { + App string `json:"app"` + Authorized bool `json:"authorized"` + Provider string `json:"provider,omitempty"` + } + + var apps []ceApp + for name, client := range clients { + apps = append(apps, ceApp{ + App: name, + Authorized: manager.HasValidToken(name), + Provider: client.ProviderName, + }) + } + + if outputJSON { + if len(apps) == 0 { + apps = make([]ceApp, 0) + } + jsonData, err := json.MarshalIndent(apps, "", " ") + if err != nil { + return err + } + fmt.Println(string(jsonData)) + return nil + } + + var rows [][]string + for _, app := range apps { + authorized := "not authorized" + if app.Authorized { + authorized = "authorized" + } + rows = append(rows, []string{app.App, authorized}) + } + formatting.PrettyPrintTable(rows, []int{80, 120}) + return nil +} diff --git a/pkg/desktop/features.go b/pkg/desktop/features.go index a286b7889..0468839c4 100644 --- a/pkg/desktop/features.go +++ b/pkg/desktop/features.go @@ -157,6 +157,11 @@ func IsRunningInDockerDesktop(ctx context.Context) bool { return false } + // Allow explicit CE mode override for non-Desktop Docker engines. + if os.Getenv("DOCKER_MCP_USE_CE") == "true" { + return false + } + // Always running in Docker Desktop on Windows and macOS if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { return true diff --git a/pkg/desktop/features_test.go b/pkg/desktop/features_test.go new file mode 100644 index 000000000..b4bfbe755 --- /dev/null +++ b/pkg/desktop/features_test.go @@ -0,0 +1,44 @@ +package desktop + +import ( + "context" + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsRunningInDockerDesktop_CEModeEnvVar(t *testing.T) { + t.Run("DOCKER_MCP_USE_CE=true returns false", func(t *testing.T) { + t.Setenv("DOCKER_MCP_USE_CE", "true") + assert.False(t, IsRunningInDockerDesktop(context.Background()), + "Should return false when DOCKER_MCP_USE_CE=true") + }) + + t.Run("DOCKER_MCP_USE_CE=false does not override", func(t *testing.T) { + t.Setenv("DOCKER_MCP_USE_CE", "false") + // Without the override, platform default behavior applies + // Just verify it doesn't panic + _ = IsRunningInDockerDesktop(context.Background()) + }) + + t.Run("DOCKER_MCP_USE_CE unset does not override", func(t *testing.T) { + os.Unsetenv("DOCKER_MCP_USE_CE") + // Should not panic and should use default platform behavior + _ = IsRunningInDockerDesktop(context.Background()) + }) +} + +func TestIsRunningInDockerDesktop_InContainer(t *testing.T) { + t.Run("DOCKER_MCP_IN_CONTAINER=1 returns false", func(t *testing.T) { + t.Setenv("DOCKER_MCP_IN_CONTAINER", "1") + assert.False(t, IsRunningInDockerDesktop(context.Background()), + "Should return false when running in container") + }) +} + +func TestIsRunningInDockerDesktop_NoDockerDesktopContext(t *testing.T) { + ctx := WithNoDockerDesktop(context.Background()) + assert.False(t, IsRunningInDockerDesktop(ctx), + "Should return false when context has NoDockerDesktop set") +} diff --git a/pkg/docker/client.go b/pkg/docker/client.go index d861d4768..53e03c4ed 100644 --- a/pkg/docker/client.go +++ b/pkg/docker/client.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "os" "runtime" "sync" @@ -56,6 +57,10 @@ func NewClient(cli command.Cli) Client { } func RunningInDockerCE(ctx context.Context, dockerCli command.Cli) (bool, error) { + if os.Getenv("DOCKER_MCP_USE_CE") == "true" { + return true, nil + } + if runtime.GOOS == "windows" || runtime.GOOS == "darwin" { return false, nil } diff --git a/pkg/docker/client_test.go b/pkg/docker/client_test.go new file mode 100644 index 000000000..55a3c73e6 --- /dev/null +++ b/pkg/docker/client_test.go @@ -0,0 +1,32 @@ +package docker + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunningInDockerCE_CEModeEnvVar(t *testing.T) { + t.Run("DOCKER_MCP_USE_CE=true returns true", func(t *testing.T) { + t.Setenv("DOCKER_MCP_USE_CE", "true") + result, err := RunningInDockerCE(t.Context(), nil) + assert.NoError(t, err) + assert.True(t, result, "Should return true when DOCKER_MCP_USE_CE=true") + }) + + t.Run("DOCKER_MCP_USE_CE=false does not short-circuit", func(t *testing.T) { + t.Setenv("DOCKER_MCP_USE_CE", "false") + // Without the env override, the platform default applies (assumes Desktop). + result, err := RunningInDockerCE(t.Context(), nil) + assert.NoError(t, err) + assert.False(t, result) + }) + + t.Run("DOCKER_MCP_USE_CE unset does not short-circuit", func(t *testing.T) { + os.Unsetenv("DOCKER_MCP_USE_CE") + result, err := RunningInDockerCE(t.Context(), nil) + assert.NoError(t, err) + assert.False(t, result) + }) +} diff --git a/pkg/gateway/clientpool.go b/pkg/gateway/clientpool.go index b9afe3268..32760aae0 100644 --- a/pkg/gateway/clientpool.go +++ b/pkg/gateway/clientpool.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "encoding/json" "fmt" "os" "os/exec" @@ -211,6 +212,36 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) { } } +// normalizeArguments converts tool arguments from various representations +// (json.RawMessage, []byte, map[string]any, nil) into a consistent +// map[string]any for template evaluation. MCP transports may deliver +// arguments in any of these forms depending on the caller. +func normalizeArguments(args any) map[string]any { + switch v := args.(type) { + case map[string]any: + return v + case json.RawMessage: + var m map[string]any + if err := json.Unmarshal(v, &m); err != nil { + log.Logf("Warning: failed to decode tool arguments RawMessage: %v", err) + return make(map[string]any) + } + return m + case []byte: + var m map[string]any + if err := json.Unmarshal(v, &m); err != nil { + log.Logf("Warning: failed to decode tool arguments JSON: %v", err) + return make(map[string]any) + } + return m + case nil: + return make(map[string]any) + default: + log.Logf("Warning: unsupported tool arguments type: %T", args) + return make(map[string]any) + } +} + func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, params *mcp.CallToolParams) (*mcp.CallToolResult, error) { args := cp.baseArgs(tool.Name) @@ -219,11 +250,7 @@ func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, p args = append(args, "--network", network) } - // Convert params.Arguments to map[string]any - arguments, ok := params.Arguments.(map[string]any) - if !ok { - arguments = make(map[string]any) - } + arguments := normalizeArguments(params.Arguments) // Volumes for _, mount := range eval.EvaluateList(tool.Container.Volumes, arguments) { diff --git a/pkg/gateway/clientpool_test.go b/pkg/gateway/clientpool_test.go index 96a149752..9602f4290 100644 --- a/pkg/gateway/clientpool_test.go +++ b/pkg/gateway/clientpool_test.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "encoding/json" "fmt" "os" "testing" @@ -278,6 +279,59 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any { return config } +func TestNormalizeArguments_MapStringAny(t *testing.T) { + input := map[string]any{"key": "value", "num": float64(42)} + result := normalizeArguments(input) + assert.Equal(t, input, result) +} + +func TestNormalizeArguments_JSONRawMessage(t *testing.T) { + raw := json.RawMessage(`{"url":"https://example.com","count":3}`) + result := normalizeArguments(raw) + assert.Equal(t, "https://example.com", result["url"]) + assert.Equal(t, float64(3), result["count"]) +} + +func TestNormalizeArguments_ByteSlice(t *testing.T) { + raw := []byte(`{"path":"/tmp/data","verbose":true}`) + result := normalizeArguments(raw) + assert.Equal(t, "/tmp/data", result["path"]) + assert.Equal(t, true, result["verbose"]) +} + +func TestNormalizeArguments_Nil(t *testing.T) { + result := normalizeArguments(nil) + assert.NotNil(t, result) + assert.Empty(t, result) +} + +func TestNormalizeArguments_UnexpectedType(t *testing.T) { + result := normalizeArguments("unexpected string") + assert.NotNil(t, result) + assert.Empty(t, result) +} + +func TestNormalizeArguments_InvalidJSON(t *testing.T) { + raw := json.RawMessage(`{not valid json}`) + result := normalizeArguments(raw) + assert.NotNil(t, result) + assert.Empty(t, result) +} + +func TestNormalizeArguments_InvalidByteSlice(t *testing.T) { + raw := []byte(`{not valid json}`) + result := normalizeArguments(raw) + assert.NotNil(t, result) + assert.Empty(t, result) +} + +func TestNormalizeArguments_EmptyJSONObject(t *testing.T) { + raw := json.RawMessage(`{}`) + result := normalizeArguments(raw) + assert.NotNil(t, result) + assert.Empty(t, result) +} + func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) { // Community server: remote URL set, but no Spec.OAuth metadata. // This verifies Gap 3: InvalidateOAuthClients matches community servers diff --git a/pkg/gateway/handlers.go b/pkg/gateway/handlers.go index 1675a20bc..0ea161b69 100644 --- a/pkg/gateway/handlers.go +++ b/pkg/gateway/handlers.go @@ -2,7 +2,6 @@ package gateway import ( "context" - "encoding/json" "fmt" "os" "time" @@ -43,18 +42,22 @@ func inferServerTransportType(serverConfig *catalog.ServerConfig) string { func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler { return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // Convert CallToolParamsRaw to CallToolParams - var args any - if len(req.Params.Arguments) > 0 { - if err := json.Unmarshal(req.Params.Arguments, &args); err != nil { - return nil, fmt.Errorf("failed to unmarshal arguments: %w", err) - } - } + // Convert CallToolParamsRaw to CallToolParams. + // + // Arguments are forwarded as raw JSON (json.RawMessage) and intentionally + // not unmarshaled here. The gateway must remain schema-agnostic and avoid + // coercing tool inputs, preserving full argument fidelity for tools that + // rely on structured or typed inputs. params := &mcp.CallToolParams{ - Meta: req.Params.Meta, - Name: req.Params.Name, - Arguments: args, + Meta: req.Params.Meta, + Name: req.Params.Name, } + + // Forward raw arguments unchanged, if present. + if len(req.Params.Arguments) > 0 { + params.Arguments = req.Params.Arguments + } + return g.clientPool.runToolContainer(ctx, tool, params) } } @@ -132,19 +135,21 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _ } defer g.clientPool.ReleaseClient(client) - // Convert CallToolParamsRaw to CallToolParams - var args any - if len(req.Params.Arguments) > 0 { - if jsonErr := json.Unmarshal(req.Params.Arguments, &args); jsonErr != nil { - telemetry.RecordToolError(ctx, span, serverConfig.Name, serverTransportType, req.Params.Name) - span.SetStatus(codes.Error, "Failed to unmarshal arguments") - return nil, fmt.Errorf("failed to unmarshal arguments: %w", jsonErr) - } - } + // Convert CallToolParamsRaw to CallToolParams. + // + // NOTE: Arguments are forwarded as raw JSON (json.RawMessage) instead of being + // unmarshaled here. The gateway must not interpret or coerce tool arguments, + // as it does not own the tool schema. Preserving the raw payload ensures full + // fidelity for schema-based and typed tools and matches the MCP Go SDK + // expectations. params := &mcp.CallToolParams{ - Meta: req.Params.Meta, - Name: originalToolName, - Arguments: args, + Meta: req.Params.Meta, + Name: originalToolName, + } + + // Forward raw arguments unchanged, if present. + if len(req.Params.Arguments) > 0 { + params.Arguments = req.Params.Arguments } // Execute the tool call diff --git a/pkg/oauth/manager.go b/pkg/oauth/manager.go index e7f67c105..03be6c549 100644 --- a/pkg/oauth/manager.go +++ b/pkg/oauth/manager.go @@ -4,6 +4,8 @@ import ( "context" "fmt" "net/url" + "os" + "strings" "github.com/docker/docker-credential-helpers/credentials" "golang.org/x/oauth2" @@ -23,13 +25,31 @@ type Manager struct { redirectURI string } -// NewManager creates a new OAuth manager for CE mode +// NewManager creates a new OAuth manager. +// In CE mode, the redirect URI must point to the local gateway callback +// instead of the SaaS endpoint (mcp.docker.com). func NewManager(credHelper credentials.Helper) *Manager { + redirectURI := DefaultRedirectURI + + // CE mode requires a local redirect URI because the OAuth callback + // is handled by the local mcp-gateway process, not by Docker SaaS. + // + // Example: + // http://localhost:5000/callback + if os.Getenv("DOCKER_MCP_USE_CE") == "true" { + if v := os.Getenv("DOCKER_MCP_OAUTH_REDIRECT_URI"); v != "" { + redirectURI = v + } else { + // Default CE callback used by the local OAuth proxy + redirectURI = "http://localhost:5000/callback" + } + } + return &Manager{ - dcrManager: dcr.NewManager(credHelper, DefaultRedirectURI), + dcrManager: dcr.NewManager(credHelper, redirectURI), tokenStore: NewTokenStore(credHelper), stateManager: NewStateManager(), - redirectURI: DefaultRedirectURI, + redirectURI: redirectURI, } } @@ -125,6 +145,16 @@ func (m *Manager) BuildAuthorizationURL(_ context.Context, serverName string, sc // ExchangeCode exchanges an authorization code for an access token func (m *Manager) ExchangeCode(ctx context.Context, code string, state string) error { + // Strip the mcp-gateway:PORT: prefix if present. + // BuildAuthorizationURL formats state as "mcp-gateway:PORT:UUID" for proxy routing, + // but the StateManager only stores the base UUID. + if strings.HasPrefix(state, "mcp-gateway:") { + parts := strings.SplitN(state, ":", 3) + if len(parts) == 3 { + state = parts[2] + } + } + // Validate state and retrieve verifier serverName, verifier, err := m.stateManager.Validate(state) if err != nil { @@ -183,3 +213,18 @@ func (m *Manager) RevokeToken(_ context.Context, serverName string) error { func (m *Manager) DeleteDCRClient(serverName string) error { return m.dcrManager.DeleteDCRClient(serverName) } + +// ListDCRClients returns all registered DCR clients +func (m *Manager) ListDCRClients() (map[string]dcr.Client, error) { + return m.dcrManager.ListDCRClients() +} + +// HasValidToken checks if a valid OAuth token exists for the given server +func (m *Manager) HasValidToken(serverName string) bool { + dcrClient, err := m.dcrManager.GetDCRClient(serverName) + if err != nil { + return false + } + _, err = m.tokenStore.Retrieve(dcrClient) + return err == nil +} diff --git a/pkg/oauth/manager_test.go b/pkg/oauth/manager_test.go index 756f1928b..9ea2d1375 100644 --- a/pkg/oauth/manager_test.go +++ b/pkg/oauth/manager_test.go @@ -2,6 +2,7 @@ package oauth import ( "context" + "os" "strings" "testing" "time" @@ -344,6 +345,101 @@ func TestManager_CallbackURLParsing(t *testing.T) { } } +func TestManager_ExchangeCode_StripsStatePrefix(t *testing.T) { + manager := setupTestManager(t) + serverName := "test-server" + + setupTestDCRClient(t, manager, serverName) + + // Generate a state via BuildAuthorizationURL with a callback URL, + // which produces "mcp-gateway:PORT:UUID" format. + _, baseState, _, err := manager.BuildAuthorizationURL( + context.Background(), + serverName, + []string{"read"}, + "http://localhost:8080/callback", + ) + require.NoError(t, err) + + // Simulate the prefixed state that would come back from the OAuth callback + prefixedState := "mcp-gateway:8080:" + baseState + + // ExchangeCode will fail at token exchange (no real server), but it should + // get past state validation — meaning the prefix was correctly stripped. + err = manager.ExchangeCode(context.Background(), "test-code", prefixedState) + require.Error(t, err) + // If prefix stripping failed, we'd get "invalid state parameter". + // If it succeeded, we get a token exchange error instead. + assert.NotContains(t, err.Error(), "invalid state parameter") +} + +func TestManager_ExchangeCode_PlainStateStillWorks(t *testing.T) { + manager := setupTestManager(t) + serverName := "test-server" + + setupTestDCRClient(t, manager, serverName) + + // Generate a state without callback URL (no prefix) + _, baseState, _, err := manager.BuildAuthorizationURL( + context.Background(), + serverName, + []string{"read"}, + "", + ) + require.NoError(t, err) + + // ExchangeCode should still validate plain UUIDs (no prefix to strip) + err = manager.ExchangeCode(context.Background(), "test-code", baseState) + require.Error(t, err) + assert.NotContains(t, err.Error(), "invalid state parameter") +} + +func TestManager_NewManager_CEModeRedirectURI(t *testing.T) { + // Save and restore env vars + origCE := os.Getenv("DOCKER_MCP_USE_CE") + origURI := os.Getenv("DOCKER_MCP_OAUTH_REDIRECT_URI") + defer func() { + os.Setenv("DOCKER_MCP_USE_CE", origCE) + os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", origURI) + }() + + t.Run("default mode uses SaaS redirect", func(t *testing.T) { + os.Setenv("DOCKER_MCP_USE_CE", "") + os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "") + + helper := newFakeCredentialHelper() + manager := NewManager(helper) + assert.Equal(t, DefaultRedirectURI, manager.redirectURI) + }) + + t.Run("CE mode uses localhost redirect", func(t *testing.T) { + os.Setenv("DOCKER_MCP_USE_CE", "true") + os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "") + + helper := newFakeCredentialHelper() + manager := NewManager(helper) + assert.Equal(t, "http://localhost:5000/callback", manager.redirectURI) + }) + + t.Run("CE mode with custom redirect URI", func(t *testing.T) { + os.Setenv("DOCKER_MCP_USE_CE", "true") + os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "http://localhost:9999/custom") + + helper := newFakeCredentialHelper() + manager := NewManager(helper) + assert.Equal(t, "http://localhost:9999/custom", manager.redirectURI) + }) + + t.Run("non-CE mode ignores custom redirect URI", func(t *testing.T) { + os.Setenv("DOCKER_MCP_USE_CE", "false") + os.Setenv("DOCKER_MCP_OAUTH_REDIRECT_URI", "http://localhost:9999/custom") + + helper := newFakeCredentialHelper() + manager := NewManager(helper) + assert.Equal(t, DefaultRedirectURI, manager.redirectURI) + }) +} + func TestManager_StateFormatWithPort(t *testing.T) { manager := setupTestManager(t) serverName := "test-server" @@ -372,3 +468,37 @@ func TestManager_StateFormatWithPort(t *testing.T) { assert.NotContains(t, baseState, "mcp-gateway") assert.NotContains(t, baseState, ":") } + +func TestManager_ListDCRClients(t *testing.T) { + manager := setupTestManager(t) + + t.Run("empty list", func(t *testing.T) { + clients, err := manager.ListDCRClients() + require.NoError(t, err) + assert.Empty(t, clients) + }) + + t.Run("returns registered clients", func(t *testing.T) { + setupTestDCRClient(t, manager, "server-a") + setupTestDCRClient(t, manager, "server-b") + + clients, err := manager.ListDCRClients() + require.NoError(t, err) + assert.Len(t, clients, 2) + assert.Contains(t, clients, "server-a") + assert.Contains(t, clients, "server-b") + }) +} + +func TestManager_HasValidToken(t *testing.T) { + manager := setupTestManager(t) + + t.Run("no DCR client returns false", func(t *testing.T) { + assert.False(t, manager.HasValidToken("nonexistent")) + }) + + t.Run("DCR client without token returns false", func(t *testing.T) { + setupTestDCRClient(t, manager, "no-token-server") + assert.False(t, manager.HasValidToken("no-token-server")) + }) +}