Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions cmd/docker-mcp/oauth/ls.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,18 @@ import (

"github.com/docker/mcp-gateway/cmd/docker-mcp/secret-management/formatting"
"github.com/docker/mcp-gateway/pkg/desktop"
pkgoauth "github.com/docker/mcp-gateway/pkg/oauth"
)

func Ls(ctx context.Context, outputJSON bool) error {
if pkgoauth.IsCEMode() {
return lsCEMode(ctx, outputJSON)
}
return lsDesktopMode(ctx, outputJSON)
}

// lsDesktopMode lists OAuth apps via Docker Desktop (existing behavior)
func lsDesktopMode(ctx context.Context, outputJSON bool) error {
client := desktop.NewAuthClient()

// Get OAuth apps from Docker Desktop (includes both built-in and DCR providers)
Expand Down Expand Up @@ -41,3 +50,52 @@ func Ls(ctx context.Context, outputJSON bool) error {
formatting.PrettyPrintTable(rows, []int{80, 120})
return nil
}

// lsCEMode lists OAuth apps in standalone CE mode using local credential storage
func lsCEMode(_ context.Context, outputJSON bool) error {
credHelper := pkgoauth.NewReadWriteCredentialHelper()
manager := pkgoauth.NewManager(credHelper)

clients, err := manager.ListDCRClients()
if err != nil {
return fmt.Errorf("failed to list OAuth apps: %w", err)
}

type ceApp struct {
App string `json:"app"`
Authorized bool `json:"authorized"`
Provider string `json:"provider,omitempty"`
}

var apps []ceApp
for name, client := range clients {
apps = append(apps, ceApp{
App: name,
Authorized: manager.HasValidToken(name),
Provider: client.ProviderName,
})
}

if outputJSON {
if len(apps) == 0 {
apps = make([]ceApp, 0)
}
jsonData, err := json.MarshalIndent(apps, "", " ")
if err != nil {
return err
}
fmt.Println(string(jsonData))
return nil
}

var rows [][]string
for _, app := range apps {
authorized := "not authorized"
if app.Authorized {
authorized = "authorized"
}
rows = append(rows, []string{app.App, authorized})
}
formatting.PrettyPrintTable(rows, []int{80, 120})
return nil
}
5 changes: 5 additions & 0 deletions pkg/desktop/features.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,11 @@ func IsRunningInDockerDesktop(ctx context.Context) bool {
return false
}

// Allow explicit CE mode override for non-Desktop Docker engines.
if os.Getenv("DOCKER_MCP_USE_CE") == "true" {
return false
}

// Always running in Docker Desktop on Windows and macOS
if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
return true
Expand Down
44 changes: 44 additions & 0 deletions pkg/desktop/features_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package desktop

import (
"context"
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestIsRunningInDockerDesktop_CEModeEnvVar(t *testing.T) {
t.Run("DOCKER_MCP_USE_CE=true returns false", func(t *testing.T) {
t.Setenv("DOCKER_MCP_USE_CE", "true")
assert.False(t, IsRunningInDockerDesktop(context.Background()),
"Should return false when DOCKER_MCP_USE_CE=true")
})

t.Run("DOCKER_MCP_USE_CE=false does not override", func(t *testing.T) {
t.Setenv("DOCKER_MCP_USE_CE", "false")
// Without the override, platform default behavior applies
// Just verify it doesn't panic
_ = IsRunningInDockerDesktop(context.Background())
})

t.Run("DOCKER_MCP_USE_CE unset does not override", func(t *testing.T) {
os.Unsetenv("DOCKER_MCP_USE_CE")
// Should not panic and should use default platform behavior
_ = IsRunningInDockerDesktop(context.Background())
})
}

func TestIsRunningInDockerDesktop_InContainer(t *testing.T) {
t.Run("DOCKER_MCP_IN_CONTAINER=1 returns false", func(t *testing.T) {
t.Setenv("DOCKER_MCP_IN_CONTAINER", "1")
assert.False(t, IsRunningInDockerDesktop(context.Background()),
"Should return false when running in container")
})
}

func TestIsRunningInDockerDesktop_NoDockerDesktopContext(t *testing.T) {
ctx := WithNoDockerDesktop(context.Background())
assert.False(t, IsRunningInDockerDesktop(ctx),
"Should return false when context has NoDockerDesktop set")
}
5 changes: 5 additions & 0 deletions pkg/docker/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"io"
"os"
"runtime"
"sync"

Expand Down Expand Up @@ -56,6 +57,10 @@ func NewClient(cli command.Cli) Client {
}

func RunningInDockerCE(ctx context.Context, dockerCli command.Cli) (bool, error) {
if os.Getenv("DOCKER_MCP_USE_CE") == "true" {
return true, nil
}

if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
return false, nil
}
Expand Down
32 changes: 32 additions & 0 deletions pkg/docker/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package docker

import (
"os"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRunningInDockerCE_CEModeEnvVar(t *testing.T) {
t.Run("DOCKER_MCP_USE_CE=true returns true", func(t *testing.T) {
t.Setenv("DOCKER_MCP_USE_CE", "true")
result, err := RunningInDockerCE(t.Context(), nil)
assert.NoError(t, err)
assert.True(t, result, "Should return true when DOCKER_MCP_USE_CE=true")
})

t.Run("DOCKER_MCP_USE_CE=false does not short-circuit", func(t *testing.T) {
t.Setenv("DOCKER_MCP_USE_CE", "false")
// Without the env override, the platform default applies (assumes Desktop).
result, err := RunningInDockerCE(t.Context(), nil)
assert.NoError(t, err)
assert.False(t, result)
})

t.Run("DOCKER_MCP_USE_CE unset does not short-circuit", func(t *testing.T) {
os.Unsetenv("DOCKER_MCP_USE_CE")
result, err := RunningInDockerCE(t.Context(), nil)
assert.NoError(t, err)
assert.False(t, result)
})
}
37 changes: 32 additions & 5 deletions pkg/gateway/clientpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
Expand Down Expand Up @@ -211,6 +212,36 @@ func (cp *clientPool) InvalidateOAuthClients(provider string) {
}
}

// normalizeArguments converts tool arguments from various representations
// (json.RawMessage, []byte, map[string]any, nil) into a consistent
// map[string]any for template evaluation. MCP transports may deliver
// arguments in any of these forms depending on the caller.
func normalizeArguments(args any) map[string]any {
switch v := args.(type) {
case map[string]any:
return v
case json.RawMessage:
var m map[string]any
if err := json.Unmarshal(v, &m); err != nil {
log.Logf("Warning: failed to decode tool arguments RawMessage: %v", err)
return make(map[string]any)
}
return m
case []byte:
var m map[string]any
if err := json.Unmarshal(v, &m); err != nil {
log.Logf("Warning: failed to decode tool arguments JSON: %v", err)
return make(map[string]any)
}
return m
case nil:
return make(map[string]any)
default:
log.Logf("Warning: unsupported tool arguments type: %T", args)
return make(map[string]any)
}
}

func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, params *mcp.CallToolParams) (*mcp.CallToolResult, error) {
args := cp.baseArgs(tool.Name)

Expand All @@ -219,11 +250,7 @@ func (cp *clientPool) runToolContainer(ctx context.Context, tool catalog.Tool, p
args = append(args, "--network", network)
}

// Convert params.Arguments to map[string]any
arguments, ok := params.Arguments.(map[string]any)
if !ok {
arguments = make(map[string]any)
}
arguments := normalizeArguments(params.Arguments)

// Volumes
for _, mount := range eval.EvaluateList(tool.Container.Volumes, arguments) {
Expand Down
54 changes: 54 additions & 0 deletions pkg/gateway/clientpool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package gateway

import (
"context"
"encoding/json"
"fmt"
"os"
"testing"
Expand Down Expand Up @@ -278,6 +279,59 @@ func parseConfig(t *testing.T, contentYAML string) map[string]any {
return config
}

func TestNormalizeArguments_MapStringAny(t *testing.T) {
input := map[string]any{"key": "value", "num": float64(42)}
result := normalizeArguments(input)
assert.Equal(t, input, result)
}

func TestNormalizeArguments_JSONRawMessage(t *testing.T) {
raw := json.RawMessage(`{"url":"https://example.com","count":3}`)
result := normalizeArguments(raw)
assert.Equal(t, "https://example.com", result["url"])
assert.Equal(t, float64(3), result["count"])
}

func TestNormalizeArguments_ByteSlice(t *testing.T) {
raw := []byte(`{"path":"/tmp/data","verbose":true}`)
result := normalizeArguments(raw)
assert.Equal(t, "/tmp/data", result["path"])
assert.Equal(t, true, result["verbose"])
}

func TestNormalizeArguments_Nil(t *testing.T) {
result := normalizeArguments(nil)
assert.NotNil(t, result)
assert.Empty(t, result)
}

func TestNormalizeArguments_UnexpectedType(t *testing.T) {
result := normalizeArguments("unexpected string")
assert.NotNil(t, result)
assert.Empty(t, result)
}

func TestNormalizeArguments_InvalidJSON(t *testing.T) {
raw := json.RawMessage(`{not valid json}`)
result := normalizeArguments(raw)
assert.NotNil(t, result)
assert.Empty(t, result)
}

func TestNormalizeArguments_InvalidByteSlice(t *testing.T) {
raw := []byte(`{not valid json}`)
result := normalizeArguments(raw)
assert.NotNil(t, result)
assert.Empty(t, result)
}

func TestNormalizeArguments_EmptyJSONObject(t *testing.T) {
raw := json.RawMessage(`{}`)
result := normalizeArguments(raw)
assert.NotNil(t, result)
assert.Empty(t, result)
}

func TestInvalidateOAuthClients_MatchesCommunityServer(t *testing.T) {
// Community server: remote URL set, but no Spec.OAuth metadata.
// This verifies Gap 3: InvalidateOAuthClients matches community servers
Expand Down
51 changes: 28 additions & 23 deletions pkg/gateway/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package gateway

import (
"context"
"encoding/json"
"fmt"
"os"
"time"
Expand Down Expand Up @@ -43,18 +42,22 @@ func inferServerTransportType(serverConfig *catalog.ServerConfig) string {

func (g *Gateway) mcpToolHandler(tool catalog.Tool) mcp.ToolHandler {
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Convert CallToolParamsRaw to CallToolParams
var args any
if len(req.Params.Arguments) > 0 {
if err := json.Unmarshal(req.Params.Arguments, &args); err != nil {
return nil, fmt.Errorf("failed to unmarshal arguments: %w", err)
}
}
// Convert CallToolParamsRaw to CallToolParams.
//
// Arguments are forwarded as raw JSON (json.RawMessage) and intentionally
// not unmarshaled here. The gateway must remain schema-agnostic and avoid
// coercing tool inputs, preserving full argument fidelity for tools that
// rely on structured or typed inputs.
params := &mcp.CallToolParams{
Meta: req.Params.Meta,
Name: req.Params.Name,
Arguments: args,
Meta: req.Params.Meta,
Name: req.Params.Name,
}

// Forward raw arguments unchanged, if present.
if len(req.Params.Arguments) > 0 {
params.Arguments = req.Params.Arguments
}

return g.clientPool.runToolContainer(ctx, tool, params)
}
}
Expand Down Expand Up @@ -132,19 +135,21 @@ func (g *Gateway) mcpServerToolHandler(serverName string, server *mcp.Server, _
}
defer g.clientPool.ReleaseClient(client)

// Convert CallToolParamsRaw to CallToolParams
var args any
if len(req.Params.Arguments) > 0 {
if jsonErr := json.Unmarshal(req.Params.Arguments, &args); jsonErr != nil {
telemetry.RecordToolError(ctx, span, serverConfig.Name, serverTransportType, req.Params.Name)
span.SetStatus(codes.Error, "Failed to unmarshal arguments")
return nil, fmt.Errorf("failed to unmarshal arguments: %w", jsonErr)
}
}
// Convert CallToolParamsRaw to CallToolParams.
//
// NOTE: Arguments are forwarded as raw JSON (json.RawMessage) instead of being
// unmarshaled here. The gateway must not interpret or coerce tool arguments,
// as it does not own the tool schema. Preserving the raw payload ensures full
// fidelity for schema-based and typed tools and matches the MCP Go SDK
// expectations.
params := &mcp.CallToolParams{
Meta: req.Params.Meta,
Name: originalToolName,
Arguments: args,
Meta: req.Params.Meta,
Name: originalToolName,
}

// Forward raw arguments unchanged, if present.
if len(req.Params.Arguments) > 0 {
params.Arguments = req.Params.Arguments
}

// Execute the tool call
Expand Down
Loading