Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions pkg/gateway/codemode.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/docker/mcp-gateway/pkg/catalog"
"github.com/docker/mcp-gateway/pkg/codemode"
"github.com/docker/mcp-gateway/pkg/policy"
)

// serverToolSetAdapter adapts a gateway server to the codemode.ToolSet interface
Expand Down Expand Up @@ -44,6 +45,9 @@ func (a *serverToolSetAdapter) Tools(ctx context.Context) ([]*codemode.ToolWithH
// Create a handler that calls the tool on the remote server
handler := func(tool *mcp.Tool) mcp.ToolHandler {
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
if err := a.checkInvokePolicy(ctx, tool.Name, req.Session); err != nil {
return nil, err
}
// Forward the tool call to the actual server
return client.Session().CallTool(ctx, &mcp.CallToolParams{
Name: tool.Name,
Expand All @@ -61,6 +65,26 @@ func (a *serverToolSetAdapter) Tools(ctx context.Context) ([]*codemode.ToolWithH
return result, nil
}

// checkInvokePolicy enforces the ActionInvoke policy for a backend tool before
// code-mode dispatches it, matching the gate applied on the direct tool-call and
// mcp-exec paths so code-mode cannot bypass an operator-configured policy.
func (a *serverToolSetAdapter) checkInvokePolicy(ctx context.Context, toolName string, session *mcp.ServerSession) error {
if a.gateway.policyClient == nil {
return nil
}
policyReq := a.gateway.configuration.policyRequest(a.serverConfig.Name, toolName, policy.ActionInvoke)
decision, err := a.gateway.policyClient.Evaluate(ctx, policyReq)
event := buildAuditEvent(policyReq, decision, err, auditClientInfoFromSession(session))
submitAuditEvent(a.gateway.policyClient, event)
if err != nil {
return fmt.Errorf("policy check failed for %s/%s: %w", a.serverConfig.Name, toolName, err)
}
if !decision.Allowed {
return fmt.Errorf("policy denied tool %s on server %s: %s", toolName, a.serverConfig.Name, decision.Reason)
}
return nil
}

func addCodemodeHandler(g *Gateway) mcp.ToolHandler {
return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Parse parameters
Expand Down
65 changes: 65 additions & 0 deletions pkg/gateway/codemode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package gateway

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/docker/mcp-gateway/pkg/catalog"
"github.com/docker/mcp-gateway/pkg/policy"
)

// TestCodemodeAdapter_PolicyEnforcement verifies that code-mode evaluates the
// ActionInvoke policy for the target backend tool before dispatching it, so a
// code-mode script cannot reach a tool that direct invocation / mcp-exec deny.
func TestCodemodeAdapter_PolicyEnforcement(t *testing.T) {
newAdapter := func(mock *mockPolicyClient) *serverToolSetAdapter {
g := &Gateway{
policyClient: mock,
configuration: Configuration{
serverNames: []string{"backend-server"},
servers: map[string]catalog.Server{"backend-server": {Image: "img"}},
},
}
sc, _, ok := g.configuration.Find("backend-server")
require.True(t, ok)
return &serverToolSetAdapter{gateway: g, serverName: "backend-server", serverConfig: sc}
}

t.Run("blocks_denied_tool", func(t *testing.T) {
mock := newMockPolicyClient()
mock.deny("backend-server", "dangerous-tool", policy.ActionInvoke, "tool blocked for safety")
err := newAdapter(mock).checkInvokePolicy(context.Background(), "dangerous-tool", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "policy denied")
assert.Contains(t, err.Error(), "dangerous-tool")
assert.Contains(t, err.Error(), "tool blocked for safety")
})

t.Run("denies_on_error", func(t *testing.T) {
mock := newMockPolicyClient()
mock.failWith("backend-server", "dangerous-tool", policy.ActionInvoke, errors.New("policy service down"))
err := newAdapter(mock).checkInvokePolicy(context.Background(), "dangerous-tool", nil)
require.Error(t, err)
assert.Contains(t, err.Error(), "policy")
})

t.Run("allows_permitted_tool", func(t *testing.T) {
err := newAdapter(newMockPolicyClient()).checkInvokePolicy(context.Background(), "safe-tool", nil)
require.NoError(t, err)
})

t.Run("nil_policy_client_allows", func(t *testing.T) {
g := &Gateway{configuration: Configuration{
serverNames: []string{"backend-server"},
servers: map[string]catalog.Server{"backend-server": {Image: "img"}},
}}
sc, _, ok := g.configuration.Find("backend-server")
require.True(t, ok)
a := &serverToolSetAdapter{gateway: g, serverName: "backend-server", serverConfig: sc}
require.NoError(t, a.checkInvokePolicy(context.Background(), "any-tool", nil))
})
}
Loading