diff --git a/pkg/gateway/activateprofile.go b/pkg/gateway/activateprofile.go index 17d83055..2e9b2250 100644 --- a/pkg/gateway/activateprofile.go +++ b/pkg/gateway/activateprofile.go @@ -18,6 +18,7 @@ import ( "github.com/docker/mcp-gateway/pkg/gateway/project" "github.com/docker/mcp-gateway/pkg/log" "github.com/docker/mcp-gateway/pkg/oci" + "github.com/docker/mcp-gateway/pkg/policy" "github.com/docker/mcp-gateway/pkg/workingset" ) @@ -131,6 +132,9 @@ func (g *Gateway) ActivateProfile(ctx context.Context, ws workingset.WorkingSet) var validationErrors []serverValidation for _, serverName := range serversToActivate { + if err := g.checkServerLoadPolicy(ctx, profileConfig.policyRequest(serverName, "", policy.ActionLoad), nil); err != nil { + return err + } serverConfig := profileConfig.servers[serverName] validation := serverValidation{serverName: serverName} diff --git a/pkg/gateway/configset.go b/pkg/gateway/configset.go index 250873e0..2a55533b 100644 --- a/pkg/gateway/configset.go +++ b/pkg/gateway/configset.go @@ -11,6 +11,7 @@ import ( "github.com/docker/mcp-gateway/pkg/log" "github.com/docker/mcp-gateway/pkg/oci" + "github.com/docker/mcp-gateway/pkg/policy" ) type configValue struct { @@ -19,7 +20,7 @@ type configValue struct { } func configSetHandler(g *Gateway) mcp.ToolHandler { - return func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Parse parameters var params configValue @@ -58,6 +59,13 @@ func configSetHandler(g *Gateway) mcp.ToolHandler { }, nil } + if err := g.checkServerLoadPolicy(ctx, g.configuration.policyRequest(serverName, "", policy.ActionLoad), req.Session); err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Error: " + err.Error()}}, + IsError: true, + }, nil + } + // Validate config against server's schema if schema exists if serverConfig != nil && len(serverConfig.Spec.Config) > 0 { var validationErrors []string @@ -153,3 +161,29 @@ func configSetHandler(g *Gateway) mcp.ToolHandler { }, nil } } + +// checkServerLoadPolicy enforces the ActionLoad policy for a server before a +// dynamic management tool (mcp-config-set / activate-profile) mutates gateway +// state for it, mirroring the gate mcp-add already applies. Returns nil when +// allowed or when no policy client is configured. +// +// Callers pass an already-built policy.Request so the request carries the full +// server metadata from the relevant configuration. activate-profile in +// particular evaluates servers that are not yet merged into g.configuration, so +// it must build the request from the profile's configuration to avoid the +// missing-server branch in Configuration.policyRequest. +func (g *Gateway) checkServerLoadPolicy(ctx context.Context, policyReq policy.Request, session *mcp.ServerSession) error { + if g.policyClient == nil { + return nil + } + decision, err := g.policyClient.Evaluate(ctx, policyReq) + event := buildAuditEvent(policyReq, decision, err, auditClientInfoFromSession(session)) + submitAuditEvent(g.policyClient, event) + if err != nil { + return fmt.Errorf("policy check failed for server %q: %w", policyReq.Server, err) + } + if !decision.Allowed { + return fmt.Errorf("server %q is blocked by policy: %s", policyReq.Server, decision.Reason) + } + return nil +} diff --git a/pkg/gateway/policy_authz_test.go b/pkg/gateway/policy_authz_test.go new file mode 100644 index 00000000..292b2003 --- /dev/null +++ b/pkg/gateway/policy_authz_test.go @@ -0,0 +1,86 @@ +package gateway + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/mcp-gateway/pkg/catalog" + "github.com/docker/mcp-gateway/pkg/policy" +) + +func newServerLoadPolicyGateway(mock *mockPolicyClient) *Gateway { + return &Gateway{ + policyClient: mock, + configuration: Configuration{ + serverNames: []string{"backend-server"}, + servers: map[string]catalog.Server{"backend-server": {Image: "img"}}, + config: map[string]map[string]any{}, + }, + } +} + +// TestCheckServerLoadPolicy covers the shared ActionLoad gate used by the +// dynamic management tools (mcp-config-set / activate-profile). +func TestCheckServerLoadPolicy(t *testing.T) { + t.Run("blocks_denied_server", func(t *testing.T) { + mock := newMockPolicyClient() + mock.deny("backend-server", "", policy.ActionLoad, "server blocked by admin") + g := newServerLoadPolicyGateway(mock) + err := g.checkServerLoadPolicy(context.Background(), g.configuration.policyRequest("backend-server", "", policy.ActionLoad), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "blocked by policy") + }) + t.Run("denies_on_error", func(t *testing.T) { + mock := newMockPolicyClient() + mock.failWith("backend-server", "", policy.ActionLoad, errors.New("policy service down")) + g := newServerLoadPolicyGateway(mock) + err := g.checkServerLoadPolicy(context.Background(), g.configuration.policyRequest("backend-server", "", policy.ActionLoad), nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "policy") + }) + t.Run("allows_permitted_server", func(t *testing.T) { + g := newServerLoadPolicyGateway(newMockPolicyClient()) + require.NoError(t, g.checkServerLoadPolicy(context.Background(), g.configuration.policyRequest("backend-server", "", policy.ActionLoad), nil)) + }) + t.Run("nil_policy_client_allows", func(t *testing.T) { + g := &Gateway{configuration: Configuration{ + serverNames: []string{"backend-server"}, + servers: map[string]catalog.Server{"backend-server": {Image: "img"}}, + }} + require.NoError(t, g.checkServerLoadPolicy(context.Background(), g.configuration.policyRequest("backend-server", "", policy.ActionLoad), nil)) + }) +} + +// TestConfigSet_PolicyEnforcement verifies mcp-config-set refuses to mutate a +// server's config when policy denies it (and applies it when allowed). +func TestConfigSet_PolicyEnforcement(t *testing.T) { + req := &mcp.CallToolRequest{Params: &mcp.CallToolParamsRaw{ + Name: "mcp-config-set", + Arguments: json.RawMessage(`{"server":"backend-server","config":{"k":"v"}}`), + }} + + t.Run("denied_is_blocked_and_not_applied", func(t *testing.T) { + mock := newMockPolicyClient() + mock.deny("backend-server", "", policy.ActionLoad, "server blocked by admin") + g := newServerLoadPolicyGateway(mock) + res, err := configSetHandler(g)(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, res) + assert.True(t, res.IsError) + assert.Empty(t, g.configuration.config, "config must not be written when policy denies") + }) + t.Run("allowed_is_applied", func(t *testing.T) { + g := newServerLoadPolicyGateway(newMockPolicyClient()) + res, err := configSetHandler(g)(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, res) + assert.False(t, res.IsError) + assert.NotEmpty(t, g.configuration.config, "config should be written when allowed") + }) +}