From 1301dbca4d56a803573b99cb9167cb27f996ce78 Mon Sep 17 00:00:00 2001 From: taskbot Date: Mon, 16 Mar 2026 09:39:51 +0100 Subject: [PATCH 1/2] Bring composite tools into session abstraction Composite tool workflow engines were previously relying on the discovery middleware to inject DiscoveredCapabilities into the request context so that the shared stateless router could route backend tool calls within workflows. This created an implicit coupling between the middleware and composite tool execution that made unit-testing harder and was a source of integration bugs. Affected components: pkg/vmcp/router, pkg/vmcp/composer, pkg/vmcp/server, pkg/vmcp/discovery Related-to: #3872 --- pkg/vmcp/composer/workflow_engine.go | 50 +++- pkg/vmcp/composer/workflow_engine_test.go | 95 ++++++ pkg/vmcp/discovery/middleware.go | 10 +- pkg/vmcp/router/session_router.go | 66 +++++ pkg/vmcp/router/session_router_test.go | 278 ++++++++++++++++++ pkg/vmcp/server/composite_tool_converter.go | 83 ++++++ .../server/composite_tool_converter_test.go | 161 ++++++++++ pkg/vmcp/server/server.go | 116 +++++--- .../session_management_integration_test.go | 81 ++++- pkg/vmcp/server/telemetry.go | 54 ++-- 10 files changed, 912 insertions(+), 82 deletions(-) create mode 100644 pkg/vmcp/router/session_router.go create mode 100644 pkg/vmcp/router/session_router_test.go diff --git a/pkg/vmcp/composer/workflow_engine.go b/pkg/vmcp/composer/workflow_engine.go index c1d5b755c6..32648b14d2 100644 --- a/pkg/vmcp/composer/workflow_engine.go +++ b/pkg/vmcp/composer/workflow_engine.go @@ -17,7 +17,6 @@ import ( "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/conversion" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/schema" ) @@ -46,6 +45,10 @@ type workflowEngine struct { // backendClient makes calls to backend MCP servers. backendClient vmcp.BackendClient + // tools is the resolved tool list for the session, used by getToolInputSchema + // for argument type coercion. Set via NewSessionWorkflowEngine. + tools []vmcp.Tool + // templateExpander handles template expansion. templateExpander TemplateExpander @@ -93,6 +96,30 @@ func NewWorkflowEngine( } } +// NewSessionWorkflowEngine creates a per-session workflow engine bound to a resolved tool list. +// tools is required: it enables argument type coercion against the session's tool schemas. +// Use this when creating per-session engines via router.NewSessionRouter. +func NewSessionWorkflowEngine( + rtr router.Router, + backendClient vmcp.BackendClient, + elicitationHandler ElicitationProtocolHandler, + stateStore WorkflowStateStore, + auditor *audit.WorkflowAuditor, + tools []vmcp.Tool, +) Composer { + return &workflowEngine{ + router: rtr, + backendClient: backendClient, + templateExpander: NewTemplateExpander(), + contextManager: newWorkflowContextManager(), + elicitationHandler: elicitationHandler, + dagExecutor: newDAGExecutor(defaultMaxParallelSteps), + stateStore: stateStore, + auditor: auditor, + tools: tools, + } +} + // ExecuteWorkflow executes a composite tool workflow. // // TODO(rate-limiting): Add rate limiting per user/session to prevent workflow execution DoS. @@ -407,7 +434,7 @@ func (e *workflowEngine) executeToolStep( // Coerce expanded arguments to expected types based on backend tool schema. // Template expansion returns strings, but backend tools expect typed values // (integer, boolean, number) as defined in their InputSchema. - rawSchema := e.getToolInputSchema(ctx, step.Tool) + rawSchema := e.getToolInputSchema(step.Tool) s := schema.MakeSchema(rawSchema) if coerced, ok := s.TryCoerce(expandedArgs).(map[string]any); ok { expandedArgs = coerced @@ -1223,20 +1250,13 @@ func (e *workflowEngine) auditStepSkipped( } } -// getToolInputSchema looks up a tool's InputSchema from discovered capabilities. -// Returns nil if the tool is not found or capabilities are not in context. -func (*workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any { - caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) - if !ok || caps == nil { - return nil - } - - // Search in backend tools - for i := range caps.Tools { - if caps.Tools[i].Name == toolName { - return caps.Tools[i].InputSchema +// getToolInputSchema looks up a tool's InputSchema from the session-bound tools list. +// Returns nil if the engine has no tools list or the tool is not found. +func (e *workflowEngine) getToolInputSchema(toolName string) map[string]any { + for i := range e.tools { + if e.tools[i].Name == toolName { + return e.tools[i].InputSchema } } - return nil } diff --git a/pkg/vmcp/composer/workflow_engine_test.go b/pkg/vmcp/composer/workflow_engine_test.go index 8c4039b202..2a97b81ab6 100644 --- a/pkg/vmcp/composer/workflow_engine_test.go +++ b/pkg/vmcp/composer/workflow_engine_test.go @@ -741,3 +741,98 @@ func TestWorkflowEngine_WorkflowMetadataAvailableInTemplates(t *testing.T) { assert.Equal(t, WorkflowStatusCompleted, result.Status) assert.Len(t, result.Steps, 2) } + +func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing.T) { + t.Parallel() + + // Template expansion always produces strings. When the engine is created + // with NewSessionWorkflowEngine, getToolInputSchema resolves the target tool's InputSchema + // and the schema coercion layer converts "42" → 42 before calling the backend. + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routermocks.NewMockRouter(ctrl) + mockBackend := mocks.NewMockBackendClient(ctrl) + + tools := []vmcp.Tool{ + { + Name: "count_items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "limit": map[string]any{"type": "integer"}, + }, + }, + }, + } + + engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + + target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} + mockRouter.EXPECT().RouteTool(gomock.Any(), "count_items").Return(target, nil) + + // Expect the backend to receive the coerced integer, not the string "42". + coercedArgs := map[string]any{"limit": int64(42)} + mockBackend.EXPECT(). + CallTool(gomock.Any(), target, "count_items", coercedArgs, gomock.Any()). + Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"items": []any{}}, Content: []vmcp.Content{}}, nil) + + workflow := &WorkflowDefinition{ + Name: "coerce_test", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "n": map[string]any{"type": "string"}, + }, + }, + Steps: []WorkflowStep{ + { + ID: "step1", + Type: StepTypeTool, + Tool: "count_items", + // Template expansion produces a string; coercion must convert it to int. + Arguments: map[string]any{"limit": "{{.params.n}}"}, + }, + }, + } + + result, err := engine.ExecuteWorkflow(context.Background(), workflow, map[string]any{"n": "42"}) + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) +} + +func TestWorkflowEngine_SessionEngine_ToolNotInList_ReturnsNilSchema(t *testing.T) { + t.Parallel() + + // When NewSessionWorkflowEngine is used but the requested tool is not in the list, + // getToolInputSchema returns nil and coercion is a no-op. + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routermocks.NewMockRouter(ctrl) + mockBackend := mocks.NewMockBackendClient(ctrl) + + // Tools list does not include "other_tool". + tools := []vmcp.Tool{{Name: "known_tool", InputSchema: map[string]any{"type": "object"}}} + engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + + target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} + mockRouter.EXPECT().RouteTool(gomock.Any(), "other_tool").Return(target, nil) + + // Args pass through unmodified (string stays a string). + rawArgs := map[string]any{"value": "hello"} + mockBackend.EXPECT(). + CallTool(gomock.Any(), target, "other_tool", rawArgs, gomock.Any()). + Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"ok": true}, Content: []vmcp.Content{}}, nil) + + workflow := &WorkflowDefinition{ + Name: "no_schema_test", + Steps: []WorkflowStep{ + {ID: "s1", Type: StepTypeTool, Tool: "other_tool", Arguments: rawArgs}, + }, + } + + result, err := engine.ExecuteWorkflow(context.Background(), workflow, nil) + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) +} diff --git a/pkg/vmcp/discovery/middleware.go b/pkg/vmcp/discovery/middleware.go index 6b322307bf..323c7a8bc2 100644 --- a/pkg/vmcp/discovery/middleware.go +++ b/pkg/vmcp/discovery/middleware.go @@ -281,10 +281,12 @@ func handleSubsequentRequest( return ctx, fmt.Errorf("session not found: %s", sessionID) } - // Backend tool calls are routed by session-scoped handlers registered with the SDK. - // However, composite tool workflow steps go through the shared router which requires - // DiscoveredCapabilities in the context. Inject capabilities built from the session's - // routing table so composite workflows can route backend tool calls correctly. + // Backend tool handlers (created by DefaultHandlerFactory) resolve their backend + // target by calling router.RouteTool(ctx, name), which reads DiscoveredCapabilities + // from the request context. Inject capabilities built from the session's routing + // table so these handlers can route correctly on subsequent requests. + // Note: composite tool workflow engines are created per-session and route via + // SessionRouter directly, so they no longer depend on this context value. multiSess, isMulti := rawSess.(vmcpsession.MultiSession) if !isMulti { // The session is still a StreamableSession placeholder — Phase 2 diff --git a/pkg/vmcp/router/session_router.go b/pkg/vmcp/router/session_router.go new file mode 100644 index 0000000000..4f67d173b7 --- /dev/null +++ b/pkg/vmcp/router/session_router.go @@ -0,0 +1,66 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package router + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// sessionRouter is a Router implementation backed directly by a RoutingTable, +// requiring no request context to resolve capabilities. It is used by +// per-session workflow engines so that composite tool execution does not depend +// on the discovery middleware injecting DiscoveredCapabilities into the context. +type sessionRouter struct { + routingTable *vmcp.RoutingTable +} + +// NewSessionRouter creates a Router that routes from the provided RoutingTable +// without reading the request context. This is the preferred router for +// composite tool workflow engines because it couples routing to the session +// rather than to middleware-managed context values. +func NewSessionRouter(rt *vmcp.RoutingTable) Router { + return &sessionRouter{routingTable: rt} +} + +// RouteTool resolves a tool name to its backend target using the session's +// routing table directly. +func (r *sessionRouter) RouteTool(_ context.Context, toolName string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Tools == nil { + return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + } + target, exists := r.routingTable.Tools[toolName] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + } + return target, nil +} + +// RouteResource resolves a resource URI to its backend target using the +// session's routing table directly. +func (r *sessionRouter) RouteResource(_ context.Context, uri string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Resources == nil { + return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri) + } + target, exists := r.routingTable.Resources[uri] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri) + } + return target, nil +} + +// RoutePrompt resolves a prompt name to its backend target using the session's +// routing table directly. +func (r *sessionRouter) RoutePrompt(_ context.Context, name string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Prompts == nil { + return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name) + } + target, exists := r.routingTable.Prompts[name] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name) + } + return target, nil +} diff --git a/pkg/vmcp/router/session_router_test.go b/pkg/vmcp/router/session_router_test.go new file mode 100644 index 0000000000..496674753d --- /dev/null +++ b/pkg/vmcp/router/session_router_test.go @@ -0,0 +1,278 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package router_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +func TestSessionRouter_RouteTool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + toolName string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing tool", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend1", + WorkloadName: "Backend 1", + BaseURL: "http://backend1:8080", + }, + }, + }, + toolName: "test_tool", + expectedID: "backend1", + expectError: false, + }, + { + name: "tool not found", + routingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + }, + toolName: "nonexistent_tool", + expectError: true, + errorContains: "tool not found", + }, + { + name: "nil routing table", + routingTable: nil, + toolName: "test_tool", + expectError: true, + errorContains: "tool not found", + }, + { + name: "nil tools map", + routingTable: &vmcp.RoutingTable{ + Tools: nil, + }, + toolName: "test_tool", + expectError: true, + errorContains: "tool not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RouteTool(context.Background(), tt.toolName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_RouteResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + uri string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing resource", + routingTable: &vmcp.RoutingTable{ + Resources: map[string]*vmcp.BackendTarget{ + "file:///path/to/resource": { + WorkloadID: "backend2", + WorkloadName: "Backend 2", + BaseURL: "http://backend2:8080", + }, + }, + }, + uri: "file:///path/to/resource", + expectedID: "backend2", + expectError: false, + }, + { + name: "resource not found", + routingTable: &vmcp.RoutingTable{ + Resources: make(map[string]*vmcp.BackendTarget), + }, + uri: "file:///nonexistent", + expectError: true, + errorContains: "resource not found", + }, + { + name: "nil routing table", + routingTable: nil, + uri: "file:///test", + expectError: true, + errorContains: "resource not found", + }, + { + name: "nil resources map", + routingTable: &vmcp.RoutingTable{ + Resources: nil, + }, + uri: "file:///test", + expectError: true, + errorContains: "resource not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RouteResource(context.Background(), tt.uri) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_RoutePrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + promptName string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing prompt", + routingTable: &vmcp.RoutingTable{ + Prompts: map[string]*vmcp.BackendTarget{ + "greeting": { + WorkloadID: "backend3", + WorkloadName: "Backend 3", + BaseURL: "http://backend3:8080", + }, + }, + }, + promptName: "greeting", + expectedID: "backend3", + expectError: false, + }, + { + name: "prompt not found", + routingTable: &vmcp.RoutingTable{ + Prompts: make(map[string]*vmcp.BackendTarget), + }, + promptName: "nonexistent", + expectError: true, + errorContains: "prompt not found", + }, + { + name: "nil routing table", + routingTable: nil, + promptName: "test", + expectError: true, + errorContains: "prompt not found", + }, + { + name: "nil prompts map", + routingTable: &vmcp.RoutingTable{ + Prompts: nil, + }, + promptName: "test", + expectError: true, + errorContains: "prompt not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RoutePrompt(context.Background(), tt.promptName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_ConcurrentAccess(t *testing.T) { + t.Parallel() + + table := &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "tool1": {WorkloadID: "backend1"}, + "tool2": {WorkloadID: "backend2"}, + }, + Resources: map[string]*vmcp.BackendTarget{ + "res1": {WorkloadID: "backend1"}, + }, + Prompts: map[string]*vmcp.BackendTarget{ + "prompt1": {WorkloadID: "backend2"}, + }, + } + + r := router.NewSessionRouter(table) + ctx := context.Background() + + const numGoroutines = 10 + const numOperations = 100 + + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + for j := 0; j < numOperations; j++ { + _, _ = r.RouteTool(ctx, "tool1") + _, _ = r.RouteResource(ctx, "res1") + _, _ = r.RoutePrompt(ctx, "prompt1") + } + done <- true + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + target, err := r.RouteTool(ctx, "tool1") + require.NoError(t, err) + assert.Equal(t, "backend1", target.WorkloadID) +} diff --git a/pkg/vmcp/server/composite_tool_converter.go b/pkg/vmcp/server/composite_tool_converter.go index 48a2b70448..953b0bb6cf 100644 --- a/pkg/vmcp/server/composite_tool_converter.go +++ b/pkg/vmcp/server/composite_tool_converter.go @@ -7,12 +7,95 @@ package server import ( "fmt" + "strings" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/config" ) +// filterWorkflowDefsForSession returns only the workflow definitions whose every +// tool step references a backend tool that is present in the session routing table. +// +// If a session does not have access to a backend tool (e.g. due to identity-based +// filtering), any composite tool that depends on that backend tool is also excluded. +// This prevents a session from invoking a composite tool that would fail at runtime +// because one or more of its underlying tools are not routable for that session. +func filterWorkflowDefsForSession( + defs map[string]*composer.WorkflowDefinition, + rt *vmcp.RoutingTable, +) map[string]*composer.WorkflowDefinition { + if len(defs) == 0 { + return defs + } + + filtered := make(map[string]*composer.WorkflowDefinition, len(defs)) + for name, def := range defs { + if allToolStepsAccessible(def, rt) { + filtered[name] = def + } + } + return filtered +} + +// allToolStepsAccessible reports whether every tool step in the workflow +// references a backend tool that is present in the session routing table. +// Returns false if rt is nil and the workflow contains any tool steps, +// since a nil routing table means no tools are routable in this session. +// +// Composite tool step names use the convention "{workloadID}.{toolName}" where +// workloadID is a Kubernetes resource name (no dots). The routing table may store +// tools under resolved/prefixed names (e.g. "{workloadID}_echo" with prefix strategy), +// so we look up by BackendTarget.WorkloadID rather than the resolved key directly. +func allToolStepsAccessible(def *composer.WorkflowDefinition, rt *vmcp.RoutingTable) bool { + for _, step := range def.Steps { + if step.Type == composer.StepTypeTool { + if rt == nil { + return false + } + if !isToolStepAccessible(step.Tool, rt) { + return false + } + } + } + return true +} + +// isToolStepAccessible reports whether a composite tool step's tool name can be +// resolved to an accessible backend tool in the given routing table. +// +// Step tool names use the "{workloadID}.{toolName}" convention. Since conflict +// resolution strategies (e.g. prefix) may rename tools in the routing table +// (e.g. "echo" → "yardstick-backend_echo"), we check for accessibility by +// matching on WorkloadID and the original backend capability name rather than +// the resolved routing table key. +func isToolStepAccessible(stepTool string, rt *vmcp.RoutingTable) bool { + // Fast path: exact match in the routing table. + if _, ok := rt.Tools[stepTool]; ok { + return true + } + + // Parse "{workloadID}.{toolName}" convention. + // Workload IDs are Kubernetes resource names and cannot contain dots, + // so the first dot separates the workload ID from the tool name. + dotIdx := strings.Index(stepTool, ".") + if dotIdx <= 0 { + return false + } + workloadID := stepTool[:dotIdx] + originalName := stepTool[dotIdx+1:] + + for resolvedName, target := range rt.Tools { + if target.WorkloadID != workloadID { + continue + } + if target.GetBackendCapabilityName(resolvedName) == originalName { + return true + } + } + return false +} + // convertWorkflowDefsToTools converts workflow definitions to vmcp.Tool format. // // This creates the tool metadata (name, description, schema) that gets exposed diff --git a/pkg/vmcp/server/composite_tool_converter_test.go b/pkg/vmcp/server/composite_tool_converter_test.go index d8c32cecc9..6c05200a85 100644 --- a/pkg/vmcp/server/composite_tool_converter_test.go +++ b/pkg/vmcp/server/composite_tool_converter_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" + "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/config" ) @@ -375,6 +376,166 @@ func TestConvertWorkflowDefsToToolsWithOutputSchema(t *testing.T) { } } +func TestFilterWorkflowDefsForSession(t *testing.T) { + t.Parallel() + + makeRT := func(toolNames ...string) *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + for _, name := range toolNames { + rt.Tools[name] = &vmcp.BackendTarget{WorkloadID: name} + } + return rt + } + + tests := []struct { + name string + defs map[string]*composer.WorkflowDefinition + rt *vmcp.RoutingTable + wantNames []string // workflow names expected in result + }{ + { + name: "empty defs", + defs: map[string]*composer.WorkflowDefinition{}, + rt: makeRT("tool_a"), + wantNames: []string{}, + }, + { + name: "all tools accessible", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + }, + rt: makeRT("tool_a", "tool_b"), + wantNames: []string{"wf1"}, + }, + { + name: "missing tool excludes workflow", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + }, + rt: makeRT("tool_b"), + wantNames: []string{}, + }, + { + name: "partially accessible: only accessible workflow included", + defs: map[string]*composer.WorkflowDefinition{ + "wf_ok": { + Name: "wf_ok", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}, + }, + }, + "wf_restricted": { + Name: "wf_restricted", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "tool_secret"}, + }, + }, + }, + rt: makeRT("tool_a"), + wantNames: []string{"wf_ok"}, + }, + { + name: "elicitation steps do not require routing table entry", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeElicitation}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "tool_a"}, + }, + }, + }, + rt: makeRT("tool_a"), + wantNames: []string{"wf1"}, + }, + { + // Composite tool steps use "{workloadID}.{toolName}" convention. + // With prefix conflict resolution the routing table key is + // "{workloadID}_echo", but the step still uses "{workloadID}.echo". + // The filter must resolve via WorkloadID + OriginalCapabilityName. + name: "dotted step tool resolved via workload ID and original name", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "my-backend.echo"}, + }, + }, + }, + rt: func() *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + // Prefix strategy stores "my-backend_echo" as the resolved key. + rt.Tools["my-backend_echo"] = &vmcp.BackendTarget{ + WorkloadID: "my-backend", + OriginalCapabilityName: "echo", + } + return rt + }(), + wantNames: []string{"wf1"}, + }, + { + name: "dotted step tool excluded when workload not in session", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "restricted-backend.echo"}, + }, + }, + }, + rt: func() *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + rt.Tools["other-backend_echo"] = &vmcp.BackendTarget{ + WorkloadID: "other-backend", + OriginalCapabilityName: "echo", + } + return rt + }(), + wantNames: []string{}, + }, + { + name: "nil routing table excludes workflows with tool steps", + defs: map[string]*composer.WorkflowDefinition{ + "wf_tool": { + Name: "wf_tool", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + "wf_elicit_only": { + Name: "wf_elicit_only", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeElicitation}}, + }, + }, + rt: nil, + wantNames: []string{"wf_elicit_only"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := filterWorkflowDefsForSession(tt.defs, tt.rt) + + if len(got) != len(tt.wantNames) { + t.Errorf("filterWorkflowDefsForSession() returned %d defs, want %d (%v)", + len(got), len(tt.wantNames), tt.wantNames) + } + for _, name := range tt.wantNames { + if _, ok := got[name]; !ok { + t.Errorf("expected workflow %q in result but it was absent", name) + } + } + }) + } +} + func TestBuildOutputPropertySchema(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 2d267ce4a3..d04178d22f 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -213,11 +213,17 @@ type Server struct { // Thread-safety: Safe for concurrent reads (no writes after initialization). workflowDefs map[string]*composer.WorkflowDefinition - // Workflow executors for composite tools (adapters around composer + definition). - // Used by capability adapter to create composite tool handlers. - // Initialized during construction and read-only thereafter. - // Thread-safety: Safe for concurrent reads (no writes after initialization). - workflowExecutors map[string]adapter.WorkflowExecutor + // composerFactory creates a per-session workflow Composer bound to the + // session's routing table and tool list. Called at session registration + // time so that composite tool execution routes via the session rather than + // depending on the discovery middleware injecting DiscoveredCapabilities + // into the request context. + composerFactory func(rt *vmcp.RoutingTable, tools []vmcp.Tool) composer.Composer + + // workflowInstruments holds pre-created OTEL metric instruments for workflow + // telemetry. Nil when telemetry is disabled. Created once at server startup + // and shared across all per-session executor wrappers. + workflowInstruments *workflowExecutorInstruments // Ready channel signals when the server is ready to accept connections. // Closed once the listener is created and serving. @@ -336,23 +342,34 @@ func New( stateStore := composer.NewInMemoryStateStore(5*time.Minute, 1*time.Hour) workflowComposer := composer.NewWorkflowEngine(rt, backendClient, elicitationHandler, stateStore, workflowAuditor) - // Validate workflows and create executors (fail fast on invalid workflows) - var workflowExecutors map[string]adapter.WorkflowExecutor + // composerFactory builds a per-session workflow engine at session registration + // time, binding composite tool routing to the session's own routing table and + // tool list. This removes composite tools' dependency on the discovery middleware + // injecting DiscoveredCapabilities into the request context. + sessionComposerFactory := func(sessionRT *vmcp.RoutingTable, sessionTools []vmcp.Tool) composer.Composer { + return composer.NewSessionWorkflowEngine( + router.NewSessionRouter(sessionRT), backendClient, elicitationHandler, stateStore, workflowAuditor, + sessionTools, + ) + } + + // Validate workflows (fail fast on invalid definitions) var err error - workflowDefs, workflowExecutors, err = validateAndCreateExecutors(workflowComposer, workflowDefs) + workflowDefs, err = validateWorkflows(workflowComposer, workflowDefs) if err != nil { return nil, fmt.Errorf("workflow validation failed: %w", err) } - // Decorate workflow executors with telemetry if provider is configured - if cfg.TelemetryProvider != nil && len(workflowExecutors) > 0 { - workflowExecutors, err = monitorWorkflowExecutors( + // Pre-create workflow telemetry instruments once so they can be reused + // across all per-session executor wrappers without re-registering metrics. + var workflowInstruments *workflowExecutorInstruments + if cfg.TelemetryProvider != nil && len(workflowDefs) > 0 { + workflowInstruments, err = newWorkflowExecutorInstruments( cfg.TelemetryProvider.MeterProvider(), cfg.TelemetryProvider.TracerProvider(), - workflowExecutors, ) if err != nil { - return nil, fmt.Errorf("failed to monitor workflow executors: %w", err) + return nil, fmt.Errorf("failed to create workflow executor telemetry: %w", err) } } @@ -396,21 +413,22 @@ func New( // Create Server instance srv := &Server{ - config: cfg, - mcpServer: mcpServer, - router: rt, - backendClient: backendClient, - handlerFactory: handlerFactory, - discoveryMgr: discoveryMgr, - backendRegistry: backendRegistry, - sessionManager: sessionManager, - capabilityAdapter: capabilityAdapter, - workflowDefs: workflowDefs, - workflowExecutors: workflowExecutors, - ready: make(chan struct{}), - healthMonitor: healthMon, - statusReporter: cfg.StatusReporter, - vmcpSessionMgr: vmcpSessMgr, + config: cfg, + mcpServer: mcpServer, + router: rt, + backendClient: backendClient, + handlerFactory: handlerFactory, + discoveryMgr: discoveryMgr, + backendRegistry: backendRegistry, + sessionManager: sessionManager, + capabilityAdapter: capabilityAdapter, + workflowDefs: workflowDefs, + composerFactory: sessionComposerFactory, + workflowInstruments: workflowInstruments, + ready: make(chan struct{}), + healthMonitor: healthMon, + statusReporter: cfg.StatusReporter, + vmcpSessionMgr: vmcpSessMgr, } // Register OnRegisterSession hook to inject capabilities after SDK registers session. @@ -1025,21 +1043,29 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv return s.injectTools(ctx, session, adaptedTools, compositeSDKTools) } -// collectCompositeTools converts workflow definitions to SDK tools, +// collectCompositeTools converts workflow definitions to SDK tools for the given session, // validating that no composite tool name collides with a backend tool name. // Returns an empty slice (not an error) if no workflow defs are configured or conflicts are found. +// Composite tools whose underlying backend tools are not routable in this session are excluded, +// so a session that lacks access to a backend tool also cannot access composite tools that depend on it. func (s *Server) collectCompositeTools(sessionID string) ([]server.ServerTool, error) { if len(s.workflowDefs) == 0 { return nil, nil } - compositeTools := convertWorkflowDefsToTools(s.workflowDefs) multiSess, hasSess := s.vmcpSessionMgr.GetMultiSession(sessionID) if !hasSess { slog.Error("session not found after creation; skipping composite tools", "session_id", sessionID) return nil, nil } + + sessionDefs := filterWorkflowDefsForSession(s.workflowDefs, multiSess.GetRoutingTable()) + if len(sessionDefs) == 0 { + return nil, nil + } + + compositeTools := convertWorkflowDefsToTools(sessionDefs) if err := validateNoToolConflicts(multiSess.Tools(), compositeTools); err != nil { slog.Error("composite tool name conflict detected; skipping composite tools", "session_id", sessionID, @@ -1047,7 +1073,20 @@ func (s *Server) collectCompositeTools(sessionID string) ([]server.ServerTool, e return nil, nil } - sdkTools, err := s.capabilityAdapter.ToCompositeToolSDKTools(compositeTools, s.workflowExecutors) + // Build per-session workflow executors so that composite tool routing uses + // the session's own routing table rather than DiscoveredCapabilities from + // the request context (which is injected by the discovery middleware). + sessionComposer := s.composerFactory(multiSess.GetRoutingTable(), multiSess.Tools()) + sessionExecutors := make(map[string]adapter.WorkflowExecutor, len(sessionDefs)) + for name, def := range sessionDefs { + ex := newComposerWorkflowExecutor(sessionComposer, def) + if s.workflowInstruments != nil { + ex = s.workflowInstruments.wrapExecutor(name, ex) + } + sessionExecutors[name] = ex + } + + sdkTools, err := s.capabilityAdapter.ToCompositeToolSDKTools(compositeTools, sessionExecutors) if err != nil { return nil, fmt.Errorf("failed to convert composite tools: %w", err) } @@ -1102,33 +1141,30 @@ func (s *Server) injectTools( return nil } -// validateAndCreateExecutors validates workflow definitions and creates executors. +// validateWorkflows validates workflow definitions, returning only the valid ones. // // This function: // 1. Validates each workflow definition (cycle detection, tool references, etc.) // 2. Returns error on first validation failure (fail-fast) -// 3. Creates workflow executors for all valid workflows // // Failing fast on invalid workflows provides immediate user feedback and prevents // security issues (resource exhaustion from cycles, information disclosure from errors). -func validateAndCreateExecutors( +func validateWorkflows( validator composer.Composer, workflowDefs map[string]*composer.WorkflowDefinition, -) (map[string]*composer.WorkflowDefinition, map[string]adapter.WorkflowExecutor, error) { +) (map[string]*composer.WorkflowDefinition, error) { if len(workflowDefs) == 0 { - return nil, nil, nil + return nil, nil } validDefs := make(map[string]*composer.WorkflowDefinition, len(workflowDefs)) - validExecutors := make(map[string]adapter.WorkflowExecutor, len(workflowDefs)) for name, def := range workflowDefs { if err := validator.ValidateWorkflow(context.Background(), def); err != nil { - return nil, nil, fmt.Errorf("invalid workflow definition '%s': %w", name, err) + return nil, fmt.Errorf("invalid workflow definition '%s': %w", name, err) } validDefs[name] = def - validExecutors[name] = newComposerWorkflowExecutor(validator, def) slog.Debug("validated workflow definition", "name", name) } @@ -1136,7 +1172,7 @@ func validateAndCreateExecutors( slog.Info("loaded valid composite tool workflows", "count", len(validDefs)) } - return validDefs, validExecutors, nil + return validDefs, nil } // GetBackendHealthStatus returns the health status of a specific backend. diff --git a/pkg/vmcp/server/session_management_integration_test.go b/pkg/vmcp/server/session_management_integration_test.go index 127f30b295..7ddfc3785e 100644 --- a/pkg/vmcp/server/session_management_integration_test.go +++ b/pkg/vmcp/server/session_management_integration_test.go @@ -115,7 +115,13 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* mock.EXPECT().Resources().Return(nil).AnyTimes() mock.EXPECT().Prompts().Return(nil).AnyTimes() mock.EXPECT().BackendSessions().Return(nil).AnyTimes() - mock.EXPECT().GetRoutingTable().Return(nil).AnyTimes() + // Build a routing table from the provided tools so that + // filterWorkflowDefsForSession can check tool accessibility per session. + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget, len(tools))} + for _, tool := range tools { + rt.Tools[tool.Name] = &vmcp.BackendTarget{WorkloadID: tool.Name} + } + mock.EXPECT().GetRoutingTable().Return(rt).AnyTimes() mock.EXPECT().ReadResource(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mock.EXPECT().GetPrompt(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() callResult := &vmcp.ToolCallResult{Content: []vmcp.Content{{Type: "text", Text: "fake result"}}} @@ -770,6 +776,79 @@ func TestIntegration_SessionManagement_CompositeToolConflict(t *testing.T) { "backend tool call should succeed after conflict detection; body: %s", string(respBody)) } +// TestIntegration_SessionManagement_CompositeToolsFilteredForSession verifies that +// composite tools whose underlying backend tools are not routable in a session are +// excluded from that session's tools/list. This enforces per-session authorization: +// a session that cannot access a backend tool also cannot access composite tools +// that depend on it. +func TestIntegration_SessionManagement_CompositeToolsFilteredForSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + // The session only has "allowed-tool"; it does NOT have "restricted-tool". + allowedTool := vmcp.Tool{Name: "allowed-tool", Description: "accessible backend tool"} + factory, _ := newMockFactory(t, ctrl, []vmcp.Tool{allowedTool}) + + // accessible-workflow only uses allowed-tool → should appear for this session. + accessibleDef := &composer.WorkflowDefinition{ + Name: "accessible-workflow", + Description: "uses only allowed backend tools", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "allowed-tool"}, + }, + } + // restricted-workflow uses restricted-tool which is absent from this session's + // routing table → must NOT appear for this session. + restrictedDef := &composer.WorkflowDefinition{ + Name: "restricted-workflow", + Description: "uses a backend tool not accessible in this session", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "allowed-tool"}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "restricted-tool"}, + }, + } + + ts := buildTestServerWithOptions(t, factory, serverOptions{ + workflowDefs: map[string]*composer.WorkflowDefinition{ + "accessible-workflow": accessibleDef, + "restricted-workflow": restrictedDef, + }, + }) + + initResp := postMCP(t, ts.URL, map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{"name": "test", "version": "1.0"}, + }, + }, "") + defer initResp.Body.Close() + require.Equal(t, http.StatusOK, initResp.StatusCode) + + sessionID := initResp.Header.Get("Mcp-Session-Id") + require.NotEmpty(t, sessionID) + + // Wait until accessible-workflow appears, then verify restricted-workflow does not. + require.Eventually(t, func() bool { + for _, n := range listToolNames(t, ts.URL, sessionID) { + if n == "accessible-workflow" { + return true + } + } + return false + }, 2*time.Second, 20*time.Millisecond, + "accessible-workflow should appear in tools/list") + + toolNames := listToolNames(t, ts.URL, sessionID) + assert.Contains(t, toolNames, "accessible-workflow", + "composite tool whose backend tools are all accessible must be visible") + assert.NotContains(t, toolNames, "restricted-workflow", + "composite tool that depends on an inaccessible backend tool must be hidden") +} + // TestIntegration_SessionManagement_OptimizerMode verifies that when an optimizer // factory is configured with session management, tools/list exposes only // find_tool and call_tool (the optimizer wraps all backend tools). diff --git a/pkg/vmcp/server/telemetry.go b/pkg/vmcp/server/telemetry.go index dc000aa184..ba275fe86f 100644 --- a/pkg/vmcp/server/telemetry.go +++ b/pkg/vmcp/server/telemetry.go @@ -258,17 +258,23 @@ func (t telemetryBackendClient) ListCapabilities( return t.backendClient.ListCapabilities(ctx, target) } -// monitorWorkflowExecutors decorates workflow executors with telemetry recording. -// It wraps each executor to emit metrics and traces for execution count, duration, and errors. -func monitorWorkflowExecutors( +// workflowExecutorInstruments holds pre-created OTEL instruments for workflow telemetry. +// Instruments are created once at server startup and reused across all session registrations +// to avoid re-registering the same metric names on every session creation. +type workflowExecutorInstruments struct { + tracer trace.Tracer + executionsTotal metric.Int64Counter + errorsTotal metric.Int64Counter + executionDuration metric.Float64Histogram +} + +// newWorkflowExecutorInstruments creates the OTEL instruments used to decorate +// per-session workflow executors. Call this once at server startup; pass the +// result to wrapExecutor at session registration time. +func newWorkflowExecutorInstruments( meterProvider metric.MeterProvider, tracerProvider trace.TracerProvider, - executors map[string]adapter.WorkflowExecutor, -) (map[string]adapter.WorkflowExecutor, error) { - if len(executors) == 0 { - return executors, nil - } - +) (*workflowExecutorInstruments, error) { meter := meterProvider.Meter(instrumentationName) executionsTotal, err := meter.Int64Counter( @@ -297,21 +303,25 @@ func monitorWorkflowExecutors( return nil, fmt.Errorf("failed to create workflow duration histogram: %w", err) } - tracer := tracerProvider.Tracer(instrumentationName) + return &workflowExecutorInstruments{ + tracer: tracerProvider.Tracer(instrumentationName), + executionsTotal: executionsTotal, + errorsTotal: errorsTotal, + executionDuration: executionDuration, + }, nil +} - monitored := make(map[string]adapter.WorkflowExecutor, len(executors)) - for name, executor := range executors { - monitored[name] = &telemetryWorkflowExecutor{ - name: name, - executor: executor, - tracer: tracer, - executionsTotal: executionsTotal, - errorsTotal: errorsTotal, - executionDuration: executionDuration, - } +// wrapExecutor returns a telemetry-decorated WorkflowExecutor using the +// pre-created instruments. Safe to call on every session registration. +func (i *workflowExecutorInstruments) wrapExecutor(name string, ex adapter.WorkflowExecutor) adapter.WorkflowExecutor { + return &telemetryWorkflowExecutor{ + name: name, + executor: ex, + tracer: i.tracer, + executionsTotal: i.executionsTotal, + errorsTotal: i.errorsTotal, + executionDuration: i.executionDuration, } - - return monitored, nil } // telemetryWorkflowExecutor wraps a WorkflowExecutor with telemetry recording. From 2ce0d3f2ee2d6e2aececebfe5ce048959bdc081b Mon Sep 17 00:00:00 2001 From: taskbot Date: Wed, 18 Mar 2026 09:25:11 +0100 Subject: [PATCH 2/2] refactor(vmcp): unify composite tools and optimizer as session decorators MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both composite tools and the optimizer now implement the MultiSession decorator pattern (same as hijackPreventionDecorator) rather than having bespoke SDK wiring in handleSessionRegistrationImpl. New session decorators: - session/compositetools: appends composite tools to Tools(), routes their CallTool invocations to per-session workflow executors - session/optimizerdec: replaces Tools() with [find_tool, call_tool]; find_tool routes through the optimizer, call_tool delegates to the underlying session for normal backend routing sessionmanager.Manager gains DecorateSession() to swap in a wrapped session after creation. handleSessionRegistrationImpl becomes a flat decoration sequence (apply compositetools → apply optimizer → register tools) with no branching on optimizer vs non-optimizer paths. adapter.WorkflowExecutor/WorkflowResult become type aliases for the compositetools package types so the two layers share a single definition. adapter.CreateOptimizerTools is deleted; its logic lives in optimizerdec. --- .../composer/elicitation_integration_test.go | 12 +- pkg/vmcp/composer/testhelpers_test.go | 8 +- .../workflow_audit_integration_test.go | 10 +- pkg/vmcp/composer/workflow_engine.go | 46 ++-- pkg/vmcp/composer/workflow_engine_test.go | 19 +- pkg/vmcp/router/default_router.go | 7 + pkg/vmcp/router/mocks/mock_router.go | 14 ++ pkg/vmcp/router/router.go | 8 + pkg/vmcp/router/session_router.go | 70 +++++- pkg/vmcp/router/session_router_test.go | 105 +++++++++ pkg/vmcp/server/adapter/handler_factory.go | 18 +- pkg/vmcp/server/adapter/optimizer_adapter.go | 120 +---------- .../server/adapter/optimizer_adapter_test.go | 97 +-------- pkg/vmcp/server/server.go | 156 +++++++------- pkg/vmcp/server/session_manager_interface.go | 9 + .../server/sessionmanager/session_manager.go | 35 +++ .../sessionmanager/session_manager_test.go | 107 ++++++++++ pkg/vmcp/session/compositetools/decorator.go | 113 ++++++++++ .../session/compositetools/decorator_test.go | 126 +++++++++++ pkg/vmcp/session/optimizerdec/decorator.go | 178 ++++++++++++++++ .../session/optimizerdec/decorator_test.go | 199 ++++++++++++++++++ .../virtualmcp_optimizer_composite_test.go | 25 ++- .../virtualmcp/virtualmcp_optimizer_test.go | 11 +- 23 files changed, 1134 insertions(+), 359 deletions(-) create mode 100644 pkg/vmcp/session/compositetools/decorator.go create mode 100644 pkg/vmcp/session/compositetools/decorator_test.go create mode 100644 pkg/vmcp/session/optimizerdec/decorator.go create mode 100644 pkg/vmcp/session/optimizerdec/decorator_test.go diff --git a/pkg/vmcp/composer/elicitation_integration_test.go b/pkg/vmcp/composer/elicitation_integration_test.go index bff4d70fc6..d4edd40c05 100644 --- a/pkg/vmcp/composer/elicitation_integration_test.go +++ b/pkg/vmcp/composer/elicitation_integration_test.go @@ -34,7 +34,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Accept(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "deployment-workflow", @@ -151,7 +151,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Decline(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -228,7 +228,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Cancel(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -275,7 +275,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Timeout(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -309,7 +309,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_NoHandler(t *testing.T) { te := newTestEngine(t) // Create engine WITHOUT elicitation handler - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -348,7 +348,7 @@ func TestWorkflowEngine_MultiStepWithElicitation(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "multi-step-workflow", diff --git a/pkg/vmcp/composer/testhelpers_test.go b/pkg/vmcp/composer/testhelpers_test.go index 29b8cbac78..be1cdb89ba 100644 --- a/pkg/vmcp/composer/testhelpers_test.go +++ b/pkg/vmcp/composer/testhelpers_test.go @@ -30,8 +30,14 @@ func newTestEngine(t *testing.T) *testEngine { t.Cleanup(ctrl.Finish) mockRouter := routermocks.NewMockRouter(ctrl) + // ResolveToolName is called by getToolInputSchema on every tool step. + // For tests that use NewWorkflowEngine (no tools list), the result is + // always nil, so a pass-through AnyTimes expectation is sufficient. + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) - engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil) // nil elicitationHandler, stateStore, and auditor for simple tests + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, nil) // nil elicitationHandler, stateStore, auditor, and tools for simple tests return &testEngine{ Engine: engine, diff --git a/pkg/vmcp/composer/workflow_audit_integration_test.go b/pkg/vmcp/composer/workflow_audit_integration_test.go index ae7c23ccbd..5e5e4d7ded 100644 --- a/pkg/vmcp/composer/workflow_audit_integration_test.go +++ b/pkg/vmcp/composer/workflow_audit_integration_test.go @@ -39,7 +39,7 @@ func TestWorkflowEngine_WithAuditor_SuccessfulWorkflow(t *testing.T) { require.NoError(t, err) // Create engine with auditor - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) // Setup simple workflow workflow := simpleWorkflow("audit-test", @@ -86,7 +86,7 @@ func TestWorkflowEngine_WithAuditor_FailedWorkflow(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := simpleWorkflow("fail-test", toolStep("step1", "tool1", map[string]any{"arg": "value"}), @@ -119,7 +119,7 @@ func TestWorkflowEngine_WithAuditor_WorkflowTimeout(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "timeout-test", @@ -162,7 +162,7 @@ func TestWorkflowEngine_WithAuditor_StepSkipped(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "skip-test", @@ -215,7 +215,7 @@ func TestWorkflowEngine_WithAuditor_RetryStep(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "retry-test", diff --git a/pkg/vmcp/composer/workflow_engine.go b/pkg/vmcp/composer/workflow_engine.go index 32648b14d2..ecba1b7b10 100644 --- a/pkg/vmcp/composer/workflow_engine.go +++ b/pkg/vmcp/composer/workflow_engine.go @@ -46,7 +46,7 @@ type workflowEngine struct { backendClient vmcp.BackendClient // tools is the resolved tool list for the session, used by getToolInputSchema - // for argument type coercion. Set via NewSessionWorkflowEngine. + // for argument type coercion. Nil means no schema-based coercion (discovery-based routing). tools []vmcp.Tool // templateExpander handles template expansion. @@ -70,12 +70,12 @@ type workflowEngine struct { // NewWorkflowEngine creates a new workflow execution engine. // -// The elicitationHandler parameter is optional. If nil, elicitation steps will fail. -// This allows the engine to be used without elicitation support for simple workflows. +// tools is the resolved tool list for schema-based argument type coercion. Pass nil +// when the engine is used for validation or discovery-based routing only. // +// The elicitationHandler parameter is optional. If nil, elicitation steps will fail. // The stateStore parameter is optional. If nil, workflow status tracking and cancellation // will not be available. Use NewInMemoryStateStore() for basic state tracking. -// // The auditor parameter is optional. If nil, workflow execution will not be audited. func NewWorkflowEngine( rtr router.Router, @@ -83,28 +83,6 @@ func NewWorkflowEngine( elicitationHandler ElicitationProtocolHandler, stateStore WorkflowStateStore, auditor *audit.WorkflowAuditor, -) Composer { - return &workflowEngine{ - router: rtr, - backendClient: backendClient, - templateExpander: NewTemplateExpander(), - contextManager: newWorkflowContextManager(), - elicitationHandler: elicitationHandler, - dagExecutor: newDAGExecutor(defaultMaxParallelSteps), - stateStore: stateStore, - auditor: auditor, - } -} - -// NewSessionWorkflowEngine creates a per-session workflow engine bound to a resolved tool list. -// tools is required: it enables argument type coercion against the session's tool schemas. -// Use this when creating per-session engines via router.NewSessionRouter. -func NewSessionWorkflowEngine( - rtr router.Router, - backendClient vmcp.BackendClient, - elicitationHandler ElicitationProtocolHandler, - stateStore WorkflowStateStore, - auditor *audit.WorkflowAuditor, tools []vmcp.Tool, ) Composer { return &workflowEngine{ @@ -434,7 +412,7 @@ func (e *workflowEngine) executeToolStep( // Coerce expanded arguments to expected types based on backend tool schema. // Template expansion returns strings, but backend tools expect typed values // (integer, boolean, number) as defined in their InputSchema. - rawSchema := e.getToolInputSchema(step.Tool) + rawSchema := e.getToolInputSchema(ctx, step.Tool) s := schema.MakeSchema(rawSchema) if coerced, ok := s.TryCoerce(expandedArgs).(map[string]any); ok { expandedArgs = coerced @@ -1250,11 +1228,17 @@ func (e *workflowEngine) auditStepSkipped( } } -// getToolInputSchema looks up a tool's InputSchema from the session-bound tools list. -// Returns nil if the engine has no tools list or the tool is not found. -func (e *workflowEngine) getToolInputSchema(toolName string) map[string]any { +// getToolInputSchema looks up a tool's InputSchema from the session-bound tools +// list. If toolName uses the dot convention "{workloadID}.{originalCapabilityName}", +// ResolveToolName is called to translate it to the conflict-resolved key before +// lookup. Returns nil if the engine has no tools list or the tool is not found. +func (e *workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any { + resolved := toolName + if e.router != nil { + resolved = e.router.ResolveToolName(ctx, toolName) + } for i := range e.tools { - if e.tools[i].Name == toolName { + if e.tools[i].Name == resolved { return e.tools[i].InputSchema } } diff --git a/pkg/vmcp/composer/workflow_engine_test.go b/pkg/vmcp/composer/workflow_engine_test.go index 2a97b81ab6..323937a3d7 100644 --- a/pkg/vmcp/composer/workflow_engine_test.go +++ b/pkg/vmcp/composer/workflow_engine_test.go @@ -390,9 +390,12 @@ func TestWorkflowEngine_ParallelExecution(t *testing.T) { defer ctrl.Finish() mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(mockRouter, mockBackend, nil, stateStore, nil) + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, stateStore, nil, nil) // Track execution timing to verify parallel execution var executionMu sync.Mutex @@ -746,12 +749,15 @@ func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing t.Parallel() // Template expansion always produces strings. When the engine is created - // with NewSessionWorkflowEngine, getToolInputSchema resolves the target tool's InputSchema + // with a bound tool list, getToolInputSchema resolves the target tool's InputSchema // and the schema coercion layer converts "42" → 42 before calling the backend. ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) tools := []vmcp.Tool{ @@ -766,7 +772,7 @@ func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing }, } - engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} mockRouter.EXPECT().RouteTool(gomock.Any(), "count_items").Return(target, nil) @@ -804,17 +810,20 @@ func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing func TestWorkflowEngine_SessionEngine_ToolNotInList_ReturnsNilSchema(t *testing.T) { t.Parallel() - // When NewSessionWorkflowEngine is used but the requested tool is not in the list, + // When a bound tool list is provided but the requested tool is not in it, // getToolInputSchema returns nil and coercion is a no-op. ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) // Tools list does not include "other_tool". tools := []vmcp.Tool{{Name: "known_tool", InputSchema: map[string]any{"type": "object"}}} - engine := NewSessionWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} mockRouter.EXPECT().RouteTool(gomock.Any(), "other_tool").Return(target, nil) diff --git a/pkg/vmcp/router/default_router.go b/pkg/vmcp/router/default_router.go index 199e94243d..3e6f3bfba7 100644 --- a/pkg/vmcp/router/default_router.go +++ b/pkg/vmcp/router/default_router.go @@ -89,6 +89,13 @@ func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.Bac ) } +// ResolveToolName returns toolName unchanged. The defaultRouter has no static +// routing table, so dot-convention resolution is not available; the caller +// should already be using resolved names when working with this router. +func (*defaultRouter) ResolveToolName(_ context.Context, toolName string) string { + return toolName +} + // RouteResource resolves a resource URI to its backend target. // With lazy discovery, this method gets capabilities from the request context // instead of using a cached routing table. diff --git a/pkg/vmcp/router/mocks/mock_router.go b/pkg/vmcp/router/mocks/mock_router.go index 6db4f729b5..768522c506 100644 --- a/pkg/vmcp/router/mocks/mock_router.go +++ b/pkg/vmcp/router/mocks/mock_router.go @@ -41,6 +41,20 @@ func (m *MockRouter) EXPECT() *MockRouterMockRecorder { return m.recorder } +// ResolveToolName mocks base method. +func (m *MockRouter) ResolveToolName(ctx context.Context, toolName string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveToolName", ctx, toolName) + ret0, _ := ret[0].(string) + return ret0 +} + +// ResolveToolName indicates an expected call of ResolveToolName. +func (mr *MockRouterMockRecorder) ResolveToolName(ctx, toolName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveToolName", reflect.TypeOf((*MockRouter)(nil).ResolveToolName), ctx, toolName) +} + // RoutePrompt mocks base method. func (m *MockRouter) RoutePrompt(ctx context.Context, name string) (*vmcp.BackendTarget, error) { m.ctrl.T.Helper() diff --git a/pkg/vmcp/router/router.go b/pkg/vmcp/router/router.go index 9b4cb52f59..63d7fc299b 100644 --- a/pkg/vmcp/router/router.go +++ b/pkg/vmcp/router/router.go @@ -28,6 +28,14 @@ type Router interface { // Returns ErrToolNotFound if the tool doesn't exist in any backend. RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error) + // ResolveToolName translates a tool name (which may use the dot-convention + // "{workloadID}.{originalCapabilityName}") to the conflict-resolved routing + // table key used in the session tools list. Returns toolName unchanged when + // the name cannot be resolved or the router has no static routing table — + // pass-through semantics so callers can use the result directly without + // special-casing the unresolvable case. + ResolveToolName(ctx context.Context, toolName string) string + // RouteResource resolves a resource URI to its backend target. // Returns ErrResourceNotFound if the resource doesn't exist in any backend. RouteResource(ctx context.Context, uri string) (*vmcp.BackendTarget, error) diff --git a/pkg/vmcp/router/session_router.go b/pkg/vmcp/router/session_router.go index 4f67d173b7..436cc18c70 100644 --- a/pkg/vmcp/router/session_router.go +++ b/pkg/vmcp/router/session_router.go @@ -6,6 +6,7 @@ package router import ( "context" "fmt" + "strings" "github.com/stacklok/toolhive/pkg/vmcp" ) @@ -28,15 +29,76 @@ func NewSessionRouter(rt *vmcp.RoutingTable) Router { // RouteTool resolves a tool name to its backend target using the session's // routing table directly. +// +// Two naming conventions are supported: +// +// 1. Exact key: the resolved/conflict-resolved name stored in the routing +// table (e.g. "my-backend_echo" after prefix conflict resolution). +// +// 2. Dot convention "{workloadID}.{toolName}": the tool name is the original +// backend capability name and the workload ID is the prefix. This mirrors +// the isToolStepAccessible logic used when registering composite tools and +// lets workflow step definitions remain stable regardless of the conflict +// resolution strategy in use. +// +// The dot convention is necessary because composite workflow steps reference +// tools by their pre-conflict-resolution name (e.g. "my-backend.echo"), while +// the routing table may store them under a prefixed key ("my-backend_echo"). func (r *sessionRouter) RouteTool(_ context.Context, toolName string) (*vmcp.BackendTarget, error) { if r.routingTable == nil || r.routingTable.Tools == nil { return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) } - target, exists := r.routingTable.Tools[toolName] - if !exists { - return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + + // Fast path: exact key match. + if target, exists := r.routingTable.Tools[toolName]; exists { + return target, nil } - return target, nil + + // Fallback: dot convention "{workloadID}.{toolName}". + // Workload IDs are Kubernetes resource names and cannot contain dots, + // so the first dot unambiguously separates the workload ID from the + // original backend capability name. + if dotIdx := strings.Index(toolName, "."); dotIdx > 0 { + workloadID := toolName[:dotIdx] + capName := toolName[dotIdx+1:] + for resolvedName, target := range r.routingTable.Tools { + if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName { + return target, nil + } + } + } + + return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) +} + +// ResolveToolName returns the routing table key (conflict-resolved name) for +// toolName. If toolName is an exact key it is returned unchanged. If it uses +// the dot convention "{workloadID}.{originalCapabilityName}", the matching +// routing table key is returned. Falls back to returning toolName unchanged +// when the routing table is absent or the name cannot be resolved (pass-through +// semantics, consistent with the Router interface contract). +func (r *sessionRouter) ResolveToolName(_ context.Context, toolName string) string { + if r.routingTable == nil || r.routingTable.Tools == nil { + return toolName + } + + // Fast path: exact key match. + if _, exists := r.routingTable.Tools[toolName]; exists { + return toolName + } + + // Fallback: dot convention "{workloadID}.{toolName}". + if dotIdx := strings.Index(toolName, "."); dotIdx > 0 { + workloadID := toolName[:dotIdx] + capName := toolName[dotIdx+1:] + for resolvedName, target := range r.routingTable.Tools { + if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName { + return resolvedName + } + } + } + + return toolName } // RouteResource resolves a resource URI to its backend target using the diff --git a/pkg/vmcp/router/session_router_test.go b/pkg/vmcp/router/session_router_test.go index 496674753d..0b1d520891 100644 --- a/pkg/vmcp/router/session_router_test.go +++ b/pkg/vmcp/router/session_router_test.go @@ -65,6 +65,54 @@ func TestSessionRouter_RouteTool(t *testing.T) { expectError: true, errorContains: "tool not found", }, + { + // Composite workflow steps use "{workloadID}.{toolName}" where toolName + // is the original backend capability name. With prefix conflict resolution + // the routing table key is "{workloadID}_toolName", so an exact match + // fails. The dot-convention fallback must resolve it correctly. + name: "dot convention resolved via workload ID and original capability name", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": { + WorkloadID: "my-backend", + WorkloadName: "My Backend", + BaseURL: "http://my-backend:8080", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.echo", + expectedID: "my-backend", + expectError: false, + }, + { + name: "dot convention: workload not in session", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "other-backend_echo": { + WorkloadID: "other-backend", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.echo", + expectError: true, + errorContains: "tool not found", + }, + { + name: "dot convention: capability name mismatch", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": { + WorkloadID: "my-backend", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.fetch", + expectError: true, + errorContains: "tool not found", + }, } for _, tt := range tests { @@ -87,6 +135,63 @@ func TestSessionRouter_RouteTool(t *testing.T) { } } +func TestSessionRouter_ResolveToolName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + toolName string + expectedName string + }{ + { + name: "exact key returned unchanged", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"}, + }, + }, + toolName: "my-backend_echo", + expectedName: "my-backend_echo", + }, + { + name: "dot convention resolves to routing table key", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"}, + }, + }, + toolName: "my-backend.echo", + expectedName: "my-backend_echo", + }, + { + name: "not found returns toolName unchanged (pass-through)", + routingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + }, + toolName: "missing_tool", + expectedName: "missing_tool", + }, + { + name: "nil routing table returns toolName unchanged (pass-through)", + routingTable: nil, + toolName: "any_tool", + expectedName: "any_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + resolved := r.ResolveToolName(context.Background(), tt.toolName) + + assert.Equal(t, tt.expectedName, resolved) + }) + } +} + func TestSessionRouter_RouteResource(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/server/adapter/handler_factory.go b/pkg/vmcp/server/adapter/handler_factory.go index 73c5ddc12c..c14d73e15b 100644 --- a/pkg/vmcp/server/adapter/handler_factory.go +++ b/pkg/vmcp/server/adapter/handler_factory.go @@ -19,6 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/conversion" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/router" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" ) //go:generate mockgen -destination=mocks/mock_handler_factory.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/server/adapter HandlerFactory @@ -43,20 +44,13 @@ type HandlerFactory interface { } // WorkflowExecutor executes composite tool workflows. -// This interface abstracts the composer to enable testing without full composer setup. -type WorkflowExecutor interface { - // ExecuteWorkflow executes the workflow with the given parameters. - ExecuteWorkflow(ctx context.Context, params map[string]any) (*WorkflowResult, error) -} +// Type alias for compositetools.WorkflowExecutor so that adapter consumers and +// the session decorator share a single interface definition. +type WorkflowExecutor = compositetools.WorkflowExecutor // WorkflowResult represents the result of a workflow execution. -type WorkflowResult struct { - // Output contains the workflow output data (typically from the last step). - Output map[string]any - - // Error contains error information if the workflow failed. - Error error -} +// Type alias for compositetools.WorkflowResult. +type WorkflowResult = compositetools.WorkflowResult // DefaultHandlerFactory creates MCP request handlers that route to backend workloads. type DefaultHandlerFactory struct { diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go index 717d4745e4..b962a0de88 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -3,124 +3,10 @@ package adapter -import ( - "context" - "encoding/json" - "fmt" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" - "github.com/stacklok/toolhive/pkg/vmcp/schema" -) - -// OptimizerToolNames defines the tool names exposed when optimizer is enabled. +// OptimizerToolNames defines the tool names exposed when optimizer mode is enabled. +// These constants are kept here for backwards compatibility with existing tests and +// callers. The actual tool implementation lives in the optimizerdec session decorator. const ( FindToolName = "find_tool" CallToolName = "call_tool" ) - -// Pre-generated schemas for optimizer tools. -// Generated at package init time so any schema errors panic at startup. -var ( - findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() - callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() -) - -// CreateOptimizerTools creates the SDK tools for optimizer mode. -// When optimizer is enabled, only these two tools are exposed to clients -// instead of all backend tools. -func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { - return []server.ServerTool{ - { - Tool: mcp.Tool{ - Name: FindToolName, - Description: "Find and return tools that can help accomplish the user's request. " + - "This searches available MCP server tools using semantic and keyword-based matching. " + - "Use this function when you need to: " + - "(1) discover what tools are available for a specific task, " + - "(2) find the right tool(s) before attempting to solve a problem, " + - "(3) check if required functionality exists in the current environment. " + - "Returns matching tools ranked by relevance including their names, descriptions, " + - "required parameters and schemas, plus token efficiency metrics showing " + - "baseline_tokens, returned_tokens, and savings_percent. " + - "Example: for 'Find good restaurants in San Jose', call with " + - "tool_description='search the web' and tool_keywords='web search restaurants'. " + - "Always call this before call_tool to discover the correct tool name and parameter schema.", - RawInputSchema: findToolInputSchema, - }, - Handler: createFindToolHandler(opt), - }, - { - Tool: mcp.Tool{ - Name: CallToolName, - Description: "Execute a specific tool with the provided parameters. " + - "Use this function to: " + - "(1) run a tool after identifying it with find_tool, " + - "(2) execute operations that require specific MCP server functionality, " + - "(3) perform actions that go beyond your built-in capabilities. " + - "Important: always use find_tool first to get the correct tool_name " + - "and parameter schema before calling this function. " + - "The parameters must match the tool's input schema as returned by find_tool. " + - "Returns the tool's execution result which may include success/failure status, " + - "result data or content, and error messages if execution failed.", - RawInputSchema: callToolInputSchema, - }, - Handler: createCallToolHandler(opt), - }, - } -} - -// createFindToolHandler creates a handler for the find_tool optimizer operation. -func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - output, err := opt.FindTool(ctx, input) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil - } - - return mcp.NewToolResultStructuredOnly(output), nil - } -} - -// createCallToolHandler creates a handler for the call_tool optimizer operation. -func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - result, err := opt.CallTool(ctx, input) - if err != nil { - // Exposing the error to the MCP client is important if you want it to correct its behavior. - // Without information on the failure, the model is pretty much hopeless in figuring out the problem. - return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil - } - - return result, nil - } -} - -// mustMarshalSchema marshals a schema to JSON, panicking on error. -// This is safe because schemas are generated from known types at startup. -// This should NOT be called by runtime code. -func mustGenerateSchema[T any]() json.RawMessage { - s, err := schema.GenerateSchema[T]() - if err != nil { - panic(fmt.Sprintf("failed to generate schema: %v", err)) - } - - data, err := json.Marshal(s) - if err != nil { - panic(fmt.Sprintf("failed to marshal schema: %v", err)) - } - - return data -} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go index f6228746c7..53fba4b1a9 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -4,103 +4,14 @@ package adapter import ( - "context" "testing" - "github.com/mark3labs/mcp-go/mcp" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stretchr/testify/assert" ) -// mockOptimizer implements optimizer.Optimizer for testing. -type mockOptimizer struct { - findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) - callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) -} - -func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - if m.findToolFunc != nil { - return m.findToolFunc(ctx, input) - } - return &optimizer.FindToolOutput{}, nil -} - -func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - if m.callToolFunc != nil { - return m.callToolFunc(ctx, input) - } - return mcp.NewToolResultText("ok"), nil -} - -func TestCreateOptimizerTools(t *testing.T) { - t.Parallel() - - opt := &mockOptimizer{} - tools := CreateOptimizerTools(opt) - - require.Len(t, tools, 2) - require.Equal(t, FindToolName, tools[0].Tool.Name) - require.Equal(t, CallToolName, tools[1].Tool.Name) -} - -func TestFindToolHandler(t *testing.T) { +func TestOptimizerToolNameConstants(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - require.Equal(t, "read files", input.ToolDescription) - return &optimizer.FindToolOutput{ - Tools: []mcp.Tool{ - { - Name: "read_file", - Description: "Read a file", - }, - }, - }, nil - }, - } - - tools := CreateOptimizerTools(opt) - handler := tools[0].Handler - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_description": "read files", - } - - result, err := handler(context.Background(), request) - require.NoError(t, err) - require.NotNil(t, result) - require.False(t, result.IsError) - require.Len(t, result.Content, 1) -} - -func TestCallToolHandler(t *testing.T) { - t.Parallel() - - opt := &mockOptimizer{ - callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - require.Equal(t, "read_file", input.ToolName) - require.Equal(t, "/etc/hosts", input.Parameters["path"]) - return mcp.NewToolResultText("file contents here"), nil - }, - } - - tools := CreateOptimizerTools(opt) - handler := tools[1].Handler - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_name": "read_file", - "parameters": map[string]any{ - "path": "/etc/hosts", - }, - } - - result, err := handler(context.Background(), request) - require.NoError(t, err) - require.NotNil(t, result) - require.False(t, result.IsError) - require.Len(t, result.Content, 1) + assert.Equal(t, "find_tool", FindToolName) + assert.Equal(t, "call_tool", CallToolName) } diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index d04178d22f..23b1d7523d 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -37,6 +37,9 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" "github.com/stacklok/toolhive/pkg/vmcp/server/sessionmanager" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" vmcpstatus "github.com/stacklok/toolhive/pkg/vmcp/status" ) @@ -340,14 +343,14 @@ func New( // The composer orchestrates multi-step workflows across backends // Use in-memory state store with 5-minute cleanup interval and 1-hour max age for completed workflows stateStore := composer.NewInMemoryStateStore(5*time.Minute, 1*time.Hour) - workflowComposer := composer.NewWorkflowEngine(rt, backendClient, elicitationHandler, stateStore, workflowAuditor) + workflowComposer := composer.NewWorkflowEngine(rt, backendClient, elicitationHandler, stateStore, workflowAuditor, nil) // composerFactory builds a per-session workflow engine at session registration // time, binding composite tool routing to the session's own routing table and // tool list. This removes composite tools' dependency on the discovery middleware // injecting DiscoveredCapabilities into the request context. sessionComposerFactory := func(sessionRT *vmcp.RoutingTable, sessionTools []vmcp.Tool) composer.Composer { - return composer.NewSessionWorkflowEngine( + return composer.NewWorkflowEngine( router.NewSessionRouter(sessionRT), backendClient, elicitationHandler, stateStore, workflowAuditor, sessionTools, ) @@ -1011,6 +1014,26 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv return retErr } + // Apply composite tools decorator (if workflow defs are configured). + // This appends composite tools to the session's tool list and routes their + // CallTool invocations through per-session workflow executors. + if len(s.workflowDefs) > 0 { + if retErr = s.applyCompositeToolsDecorator(sessionID); retErr != nil { + return retErr + } + } + + // Apply optimizer decorator (if configured). + // Must come after composite tools so the optimizer indexes composite tools too. + // Replaces the full tool list with find_tool + call_tool. + if s.config.OptimizerFactory != nil { + if retErr = s.applyOptimizerDecorator(ctx, sessionID); retErr != nil { + return retErr + } + } + + // Uniform registration — same code path regardless of which decorators are active. + // session.Tools() returns the final decorated tool list. adaptedTools, retErr := s.vmcpSessionMgr.GetAdaptedTools(sessionID) if retErr != nil { slog.Error("failed to get session-scoped tools", @@ -1027,12 +1050,6 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv return retErr } - // Collect composite SDK tools (with name-collision check against backend tools). - compositeSDKTools, retErr := s.collectCompositeTools(sessionID) - if retErr != nil { - return retErr - } - if len(adaptedResources) > 0 { if err := setSessionResourcesDirect(session, adaptedResources); err != nil { slog.Error("failed to add session resources", "session_id", sessionID, "error", err) @@ -1040,104 +1057,89 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv } } - return s.injectTools(ctx, session, adaptedTools, compositeSDKTools) -} - -// collectCompositeTools converts workflow definitions to SDK tools for the given session, -// validating that no composite tool name collides with a backend tool name. -// Returns an empty slice (not an error) if no workflow defs are configured or conflicts are found. -// Composite tools whose underlying backend tools are not routable in this session are excluded, -// so a session that lacks access to a backend tool also cannot access composite tools that depend on it. -func (s *Server) collectCompositeTools(sessionID string) ([]server.ServerTool, error) { - if len(s.workflowDefs) == 0 { - return nil, nil + if len(adaptedTools) > 0 { + if err := setSessionToolsDirect(session, adaptedTools); err != nil { + slog.Error("failed to add session tools", "session_id", sessionID, "error", err) + return err + } } + slog.Info("session capabilities injected", + "session_id", sessionID, + "tool_count", len(adaptedTools)) + return nil +} + +// applyCompositeToolsDecorator wraps the session with a compositeToolsDecorator. +// It filters workflow definitions for the session, validates name conflicts with +// backend tools, builds per-session workflow executors, then calls DecorateSession. +// +// Non-fatal conditions (no accessible defs, name conflicts) log a warning and +// leave the session undecorated rather than failing. +func (s *Server) applyCompositeToolsDecorator(sessionID string) error { multiSess, hasSess := s.vmcpSessionMgr.GetMultiSession(sessionID) if !hasSess { - slog.Error("session not found after creation; skipping composite tools", + slog.Warn("session not found after creation; skipping composite tools", "session_id", sessionID) - return nil, nil + return nil } sessionDefs := filterWorkflowDefsForSession(s.workflowDefs, multiSess.GetRoutingTable()) if len(sessionDefs) == 0 { - return nil, nil + return nil } - compositeTools := convertWorkflowDefsToTools(sessionDefs) - if err := validateNoToolConflicts(multiSess.Tools(), compositeTools); err != nil { - slog.Error("composite tool name conflict detected; skipping composite tools", + compositeToolsMeta := convertWorkflowDefsToTools(sessionDefs) + if err := validateNoToolConflicts(multiSess.Tools(), compositeToolsMeta); err != nil { + slog.Warn("composite tool name conflict detected; skipping composite tools", "session_id", sessionID, "error", err) - return nil, nil + return nil } - // Build per-session workflow executors so that composite tool routing uses - // the session's own routing table rather than DiscoveredCapabilities from - // the request context (which is injected by the discovery middleware). + // Build per-session workflow executors bound to this session's routing table. sessionComposer := s.composerFactory(multiSess.GetRoutingTable(), multiSess.Tools()) - sessionExecutors := make(map[string]adapter.WorkflowExecutor, len(sessionDefs)) - for name, def := range sessionDefs { + sessionExecutors := make(map[string]compositetools.WorkflowExecutor, len(sessionDefs)) + for _, def := range sessionDefs { ex := newComposerWorkflowExecutor(sessionComposer, def) if s.workflowInstruments != nil { - ex = s.workflowInstruments.wrapExecutor(name, ex) + ex = s.workflowInstruments.wrapExecutor(def.Name, ex) } - sessionExecutors[name] = ex + sessionExecutors[def.Name] = ex } - sdkTools, err := s.capabilityAdapter.ToCompositeToolSDKTools(compositeTools, sessionExecutors) - if err != nil { - return nil, fmt.Errorf("failed to convert composite tools: %w", err) - } - return sdkTools, nil + return s.vmcpSessionMgr.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return compositetools.NewDecorator(sess, compositeToolsMeta, sessionExecutors) + }) } -// injectTools registers backend and composite tools into the session. -// When the optimizer is configured, all tools are indexed and only -// find_tool/call_tool are exposed; otherwise tools are registered directly. -func (s *Server) injectTools( - ctx context.Context, - session server.ClientSession, - adaptedTools []server.ServerTool, - compositeSDKTools []server.ServerTool, -) error { - sessionID := session.SessionID() - - if s.config.OptimizerFactory != nil { - allTools := append(adaptedTools, compositeSDKTools...) - opt, err := s.config.OptimizerFactory(ctx, allTools) - if err != nil { - return fmt.Errorf("failed to create optimizer: %w", err) - } - if err = setSessionToolsDirect(session, adapter.CreateOptimizerTools(opt)); err != nil { - slog.Error("failed to add optimizer tools to session", "session_id", sessionID, "error", err) - return err - } - slog.Info("session capabilities injected (optimizer mode)", - "session_id", sessionID, - "indexed_tool_count", len(allTools)) - return nil +// applyOptimizerDecorator wraps the session with an optimizerDecorator. +// It reads the current session's tool list (including composite tools if applied), +// builds SDK tools for the optimizer factory to index, creates the optimizer, and +// calls DecorateSession. The optimizer replaces the tool list with find_tool + call_tool. +func (s *Server) applyOptimizerDecorator(ctx context.Context, sessionID string) error { + // Snapshot the pre-optimizer SDK tools for the optimizer to index. + // This includes composite tools if the composite decorator was already applied. + sdkTools, err := s.vmcpSessionMgr.GetAdaptedTools(sessionID) + if err != nil { + return fmt.Errorf("failed to get tools for optimizer: %w", err) } - if len(adaptedTools) > 0 { - if err := setSessionToolsDirect(session, adaptedTools); err != nil { - slog.Error("failed to add session tools", "session_id", sessionID, "error", err) - return err - } + opt, err := s.config.OptimizerFactory(ctx, sdkTools) + if err != nil { + return fmt.Errorf("failed to create optimizer: %w", err) } - if len(compositeSDKTools) > 0 { - if err := setSessionToolsDirect(session, compositeSDKTools); err != nil { - slog.Error("failed to add composite tools to session", "session_id", sessionID, "error", err) - return err - } - slog.Debug("added composite tools to session", "session_id", sessionID, "count", len(compositeSDKTools)) + + if err = s.vmcpSessionMgr.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return optimizerdec.NewDecorator(sess, opt) + }); err != nil { + return err } - slog.Info("session capabilities injected", + slog.Info("session capabilities injected (optimizer mode)", "session_id", sessionID, - "tool_count", len(adaptedTools), - "composite_tool_count", len(compositeSDKTools)) + "indexed_tool_count", len(sdkTools)) + return nil } diff --git a/pkg/vmcp/server/session_manager_interface.go b/pkg/vmcp/server/session_manager_interface.go index 359cab43f1..770cb7fff9 100644 --- a/pkg/vmcp/server/session_manager_interface.go +++ b/pkg/vmcp/server/session_manager_interface.go @@ -9,6 +9,7 @@ import ( mcpserver "github.com/mark3labs/mcp-go/server" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) // SessionManager extends the SDK's SessionIdManager with Phase 2 session creation @@ -39,4 +40,12 @@ type SessionManager interface { // Returns (nil, false) if the session does not exist or is still a placeholder. // Used to access session-scoped backend tool metadata (e.g. for conflict validation). GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) + + // DecorateSession retrieves the MultiSession for sessionID, applies fn to it, + // and stores the result back. Used to stack session decorators (composite tools, + // optimizer) after the base session is created. + DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error + + // Terminate terminates the session with the given ID, closing all backend connections. + Terminate(sessionID string) (bool, error) } diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index 71532abc0d..13bd0dd822 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -322,6 +322,41 @@ func (sm *Manager) GetMultiSession(sessionID string) (vmcpsession.MultiSession, return multiSess, ok } +// DecorateSession retrieves the MultiSession for sessionID, applies fn to it, +// and stores the result back. Returns an error if the session is not found or +// has not yet been upgraded from placeholder to MultiSession. +// +// A re-check is performed immediately before UpsertSession to guard against a +// race with Terminate(): if the session is deleted between GetMultiSession and +// UpsertSession, the upsert would silently resurrect a terminated session. The +// re-check catches that window. A narrow TOCTOU gap remains between the +// re-check and the upsert, but its consequence is bounded: Terminate() already +// called Close() on the underlying MultiSession before deleting it, so any +// resurrected decorator wraps an already-closed session and will fail on first +// use rather than leaking backend connections. +func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error { + sess, ok := sm.GetMultiSession(sessionID) + if !ok { + return fmt.Errorf("DecorateSession: session %q not found or not a multi-session", sessionID) + } + decorated := fn(sess) + if decorated == nil { + return fmt.Errorf("DecorateSession: decorator returned nil session") + } + if decorated.ID() != sessionID { + return fmt.Errorf("DecorateSession: decorator changed session ID from %q to %q", sessionID, decorated.ID()) + } + // Re-check: guard against a race with Terminate() deleting the session + // between GetMultiSession (above) and UpsertSession (below). + if _, ok := sm.GetMultiSession(sessionID); !ok { + return fmt.Errorf("DecorateSession: session %q was terminated during decoration", sessionID) + } + if err := sm.storage.UpsertSession(decorated); err != nil { + return fmt.Errorf("DecorateSession: failed to store decorated session: %w", err) + } + return nil +} + // GetAdaptedTools returns SDK-format tools for the given session, with handlers // that delegate tool invocations directly to the session's CallTool() method. // diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 481a542721..058dfd37be 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -1378,6 +1378,113 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { }) } +// --------------------------------------------------------------------------- +// Tests: DecorateSession +// --------------------------------------------------------------------------- + +func TestSessionManager_DecorateSession(t *testing.T) { + t.Parallel() + + t.Run("replaces session with decorated result", func(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{{Name: "hello", Description: "says hello"}} + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT(). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + return newMockSession(t, ctrl, id, tools), nil + }).Times(1) + + registry := newFakeRegistry() + sm, _ := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + _, err := sm.CreateSession(context.Background(), sessionID) + require.NoError(t, err) + + // Apply a decorator that wraps with an extra tool. + extraTool := vmcp.Tool{Name: "extra", Description: "extra tool"} + err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + decorated := sessionmocks.NewMockMultiSession(ctrl) + // Delegate everything to base session + decorated.EXPECT().ID().Return(sess.ID()).AnyTimes() + decorated.EXPECT().Tools().Return(append(sess.Tools(), extraTool)).AnyTimes() + // other methods delegated via AnyTimes + decorated.EXPECT().Type().Return(sess.Type()).AnyTimes() + decorated.EXPECT().CreatedAt().Return(sess.CreatedAt()).AnyTimes() + decorated.EXPECT().UpdatedAt().Return(sess.UpdatedAt()).AnyTimes() + decorated.EXPECT().Touch().AnyTimes() + decorated.EXPECT().GetData().Return(nil).AnyTimes() + decorated.EXPECT().SetData(gomock.Any()).AnyTimes() + decorated.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() + decorated.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes() + decorated.EXPECT().BackendSessions().Return(nil).AnyTimes() + decorated.EXPECT().GetRoutingTable().Return(nil).AnyTimes() + decorated.EXPECT().Prompts().Return(nil).AnyTimes() + return decorated + }) + require.NoError(t, err) + + // After decoration, GetMultiSession returns the decorated session with both tools. + multiSess, ok := sm.GetMultiSession(sessionID) + require.True(t, ok) + require.Len(t, multiSess.Tools(), 2) + assert.Equal(t, "hello", multiSess.Tools()[0].Name) + assert.Equal(t, "extra", multiSess.Tools()[1].Name) + }) + + t.Run("returns error for unknown session", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sm, _ := newTestSessionManager(t, newMockFactory(t, ctrl, newMockSession(t, ctrl, "", nil)), newFakeRegistry()) + + err := sm.DecorateSession("ghost-session", func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return sess + }) + require.Error(t, err) + }) + + t.Run("returns error if session terminated during decoration", func(t *testing.T) { + t.Parallel() + + // Simulate the race: Terminate() is called between GetMultiSession and + // UpsertSession. We do this by terminating the session inside the + // decorator fn, so the re-check that follows fn() sees it is gone. + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT(). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + sess := newMockSession(t, ctrl, id, nil) + sess.EXPECT().Close().Return(nil).AnyTimes() + return sess, nil + }).Times(1) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + _, err := sm.CreateSession(context.Background(), sessionID) + require.NoError(t, err) + + err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + // Simulate concurrent Terminate() completing during decoration. + _, _ = sm.Terminate(sessionID) + return sess + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "terminated during decoration") + + // The session must not be resurrected. + _, ok := sm.GetMultiSession(sessionID) + assert.False(t, ok, "terminated session must not be resurrected by DecorateSession") + }) +} + // --------------------------------------------------------------------------- // Helper // --------------------------------------------------------------------------- diff --git a/pkg/vmcp/session/compositetools/decorator.go b/pkg/vmcp/session/compositetools/decorator.go new file mode 100644 index 0000000000..35f672bfd5 --- /dev/null +++ b/pkg/vmcp/session/compositetools/decorator.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package compositetools provides a MultiSession decorator that adds composite +// tool (workflow) capabilities to a session. +package compositetools + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +// WorkflowExecutor executes a named composite tool workflow. +type WorkflowExecutor interface { + ExecuteWorkflow(ctx context.Context, params map[string]any) (*WorkflowResult, error) +} + +// WorkflowResult holds the output of a workflow execution. +type WorkflowResult struct { + Output map[string]any + Error error +} + +// compositeToolsDecorator wraps a MultiSession to add composite tool routing. +// It overrides Tools() to append composite tool metadata and CallTool() to +// intercept composite tool names and dispatch them to workflow executors. +// All other MultiSession methods delegate to the embedded session. +type compositeToolsDecorator struct { + sessiontypes.MultiSession + compositeTools []vmcp.Tool + executors map[string]WorkflowExecutor +} + +func errorResult(msg string) *vmcp.ToolCallResult { + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: msg}}, + IsError: true, + } +} + +// NewDecorator wraps sess with composite tool support. compositeTools is the +// metadata list appended to session.Tools(). executors maps each composite tool +// name to its workflow executor. Both may be nil/empty. +func NewDecorator( + sess sessiontypes.MultiSession, + compositeTools []vmcp.Tool, + executors map[string]WorkflowExecutor, +) sessiontypes.MultiSession { + return &compositeToolsDecorator{ + MultiSession: sess, + compositeTools: compositeTools, + executors: executors, + } +} + +// Tools returns backend tools followed by composite tools. +func (d *compositeToolsDecorator) Tools() []vmcp.Tool { + backend := d.MultiSession.Tools() + if len(d.compositeTools) == 0 { + return backend + } + out := make([]vmcp.Tool, len(backend), len(backend)+len(d.compositeTools)) + copy(out, backend) + return append(out, d.compositeTools...) +} + +// CallTool dispatches composite tool names to their workflow executors. +// Unknown names are delegated to the embedded session. +func (d *compositeToolsDecorator) CallTool( + ctx context.Context, + caller *auth.Identity, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + if exec, ok := d.executors[toolName]; ok { + slog.Debug("handling composite tool call", "tool", toolName) + res, err := exec.ExecuteWorkflow(ctx, arguments) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("workflow execution timeout", "tool", toolName, "error", err) + return errorResult("Workflow execution timeout exceeded"), nil + } + slog.Error("workflow execution failed", "tool", toolName, "error", err) + return errorResult(fmt.Sprintf("Workflow execution failed: %v", err)), nil + } + if res == nil { + slog.Error("workflow executor returned nil result", "tool", toolName) + return errorResult("Workflow executor returned nil result"), nil + } + if res.Error != nil { + slog.Error("workflow completed with error", "tool", toolName, "error", res.Error) + return errorResult(fmt.Sprintf("Workflow error: %v", res.Error)), nil + } + slog.Debug("composite tool completed successfully", "tool", toolName) + jsonBytes, err := json.Marshal(res.Output) + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal output: %v", err)), nil + } + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: string(jsonBytes)}}, + StructuredContent: res.Output, + }, nil + } + return d.MultiSession.CallTool(ctx, caller, toolName, arguments, meta) +} diff --git a/pkg/vmcp/session/compositetools/decorator_test.go b/pkg/vmcp/session/compositetools/decorator_test.go new file mode 100644 index 0000000000..ba8bfe1fd3 --- /dev/null +++ b/pkg/vmcp/session/compositetools/decorator_test.go @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package compositetools_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" + sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks" +) + +// stubExecutor is a simple WorkflowExecutor for tests. +type stubExecutor struct { + output map[string]any + err error +} + +func (s *stubExecutor) ExecuteWorkflow(_ context.Context, _ map[string]any) (*compositetools.WorkflowResult, error) { + return &compositetools.WorkflowResult{Output: s.output}, s.err +} + +func TestCompositeToolsDecorator_Tools(t *testing.T) { + t.Parallel() + + t.Run("appends composite tools to backend tools", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + backendTools := []vmcp.Tool{{Name: "backend_search", Description: "search"}} + base.EXPECT().Tools().Return(backendTools).AnyTimes() + + compositeToolList := []vmcp.Tool{{Name: "my_workflow", Description: "a workflow"}} + dec := compositetools.NewDecorator(base, compositeToolList, nil) + + got := dec.Tools() + require.Len(t, got, 2) + assert.Equal(t, "backend_search", got[0].Name) + assert.Equal(t, "my_workflow", got[1].Name) + }) + + t.Run("returns only backend tools when no composite tools", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + backendTools := []vmcp.Tool{{Name: "backend_search", Description: "search"}} + base.EXPECT().Tools().Return(backendTools).AnyTimes() + + dec := compositetools.NewDecorator(base, nil, nil) + + got := dec.Tools() + require.Len(t, got, 1) + assert.Equal(t, "backend_search", got[0].Name) + }) +} + +func TestCompositeToolsDecorator_CallTool(t *testing.T) { + t.Parallel() + + t.Run("routes composite tool name to executor", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + expectedOutput := map[string]any{"result": "done"} + exec := &stubExecutor{output: expectedOutput} + executors := map[string]compositetools.WorkflowExecutor{"my_workflow": exec} + + dec := compositetools.NewDecorator(base, []vmcp.Tool{{Name: "my_workflow"}}, executors) + result, err := dec.CallTool(context.Background(), nil, "my_workflow", map[string]any{"x": 1}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, expectedOutput, result.StructuredContent) + }) + + t.Run("delegates unknown tool name to embedded session", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + expectedResult := &vmcp.ToolCallResult{IsError: false} + base.EXPECT(). + CallTool(gomock.Any(), gomock.Any(), "backend_tool", gomock.Any(), gomock.Any()). + Return(expectedResult, nil) + + dec := compositetools.NewDecorator(base, nil, nil) + result, err := dec.CallTool(context.Background(), &auth.Identity{}, "backend_tool", nil, nil) + + require.NoError(t, err) + assert.Equal(t, expectedResult, result) + }) + + t.Run("propagates executor error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + execErr := errors.New("workflow failed") + exec := &stubExecutor{err: execErr} + executors := map[string]compositetools.WorkflowExecutor{"failing_wf": exec} + + dec := compositetools.NewDecorator(base, []vmcp.Tool{{Name: "failing_wf"}}, executors) + result, err := dec.CallTool(context.Background(), nil, "failing_wf", nil, nil) + + require.NoError(t, err) // errors surface as IsError results per MCP convention + require.NotNil(t, result) + assert.True(t, result.IsError) + }) +} diff --git a/pkg/vmcp/session/optimizerdec/decorator.go b/pkg/vmcp/session/optimizerdec/decorator.go new file mode 100644 index 0000000000..a59c00bbdd --- /dev/null +++ b/pkg/vmcp/session/optimizerdec/decorator.go @@ -0,0 +1,178 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package optimizerdec provides a MultiSession decorator that replaces the +// full tool list with two optimizer tools: find_tool and call_tool. +package optimizerdec + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/conversion" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/schema" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +const ( + // FindToolName is the tool name for semantic tool discovery. + FindToolName = "find_tool" + // CallToolName is the tool name for routing a call to any backend tool. + CallToolName = "call_tool" +) + +// Pre-generated schemas for find_tool and call_tool, computed at init time. +var ( + findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() + callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() +) + +// optimizerDecorator wraps a MultiSession to expose only find_tool and call_tool. +// Tools() returns only those two tools. CallTool("find_tool") routes through the +// optimizer's FindTool; CallTool("call_tool") routes through the optimizer's +// CallTool so that all optimizer telemetry (traces, metrics) is recorded. +type optimizerDecorator struct { + sessiontypes.MultiSession + opt optimizer.Optimizer + optimizerTools []vmcp.Tool +} + +// NewDecorator wraps sess with optimizer mode. Only find_tool and call_tool are +// exposed via Tools(). find_tool calls opt.FindTool. call_tool calls opt.CallTool, +// which routes through the instrumented optimizer (telemetry, traces, metrics). +func NewDecorator(sess sessiontypes.MultiSession, opt optimizer.Optimizer) sessiontypes.MultiSession { + return &optimizerDecorator{ + MultiSession: sess, + opt: opt, + optimizerTools: []vmcp.Tool{ + { + Name: FindToolName, + Description: "Find and return tools that can help accomplish the user's request. " + + "This searches available MCP server tools using semantic and keyword-based matching. " + + "Use this function when you need to: " + + "(1) discover what tools are available for a specific task, " + + "(2) find the right tool(s) before attempting to solve a problem, " + + "(3) check if required functionality exists in the current environment. " + + "Returns matching tools ranked by relevance including their names, descriptions, " + + "required parameters and schemas, plus token efficiency metrics showing " + + "baseline_tokens, returned_tokens, and savings_percent. " + + "Always call this before call_tool to discover the correct tool name and parameter schema.", + InputSchema: findToolInputSchema, + }, + { + Name: CallToolName, + Description: "Execute a specific tool with the provided parameters. " + + "Use this function to run a tool after identifying it with find_tool. " + + "Important: always use find_tool first to get the correct tool_name " + + "and parameter schema before calling this function.", + InputSchema: callToolInputSchema, + }, + }, + } +} + +// Tools returns only find_tool and call_tool, replacing the full backend tool list. +// A defensive copy is returned so callers cannot mutate the decorator's internal slice. +func (d *optimizerDecorator) Tools() []vmcp.Tool { + result := make([]vmcp.Tool, len(d.optimizerTools)) + copy(result, d.optimizerTools) + return result +} + +// CallTool handles find_tool and call_tool. Both route through the optimizer so +// that all optimizer telemetry is recorded. Any other tool name returns an error. +func (d *optimizerDecorator) CallTool( + ctx context.Context, + _ *auth.Identity, + toolName string, + arguments map[string]any, + _ map[string]any, +) (*vmcp.ToolCallResult, error) { + switch toolName { + case FindToolName: + return d.handleFindTool(ctx, arguments) + case CallToolName: + return d.handleCallTool(ctx, arguments) + default: + return nil, fmt.Errorf("tool not found: %s", toolName) + } +} + +func (d *optimizerDecorator) handleFindTool(ctx context.Context, arguments map[string]any) (*vmcp.ToolCallResult, error) { + input, err := schema.Translate[optimizer.FindToolInput](arguments) + if err != nil { + return errorResult(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + output, err := d.opt.FindTool(ctx, input) + if err != nil { + return errorResult(fmt.Sprintf("find_tool failed: %v", err)), nil + } + if output == nil { + return errorResult("find_tool: optimizer returned nil result"), nil + } + + jsonBytes, err := json.Marshal(output) + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal find_tool output: %v", err)), nil + } + + var structured map[string]any + _ = json.Unmarshal(jsonBytes, &structured) + + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: string(jsonBytes)}}, + StructuredContent: structured, + }, nil +} + +func (d *optimizerDecorator) handleCallTool( + ctx context.Context, + arguments map[string]any, +) (*vmcp.ToolCallResult, error) { + input, err := schema.Translate[optimizer.CallToolInput](arguments) + if err != nil { + return errorResult(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + mcpResult, err := d.opt.CallTool(ctx, input) + if err != nil { + return errorResult(fmt.Sprintf("call_tool failed: %v", err)), nil + } + if mcpResult == nil { + return errorResult("call_tool: optimizer returned nil result"), nil + } + + return mcpResultToVMCPResult(mcpResult), nil +} + +// mcpResultToVMCPResult converts an MCP SDK CallToolResult to the vmcp domain type. +func mcpResultToVMCPResult(r *mcp.CallToolResult) *vmcp.ToolCallResult { + structured, _ := r.StructuredContent.(map[string]any) + return &vmcp.ToolCallResult{ + Content: conversion.ConvertMCPContents(r.Content), + StructuredContent: structured, + IsError: r.IsError, + } +} + +func errorResult(msg string) *vmcp.ToolCallResult { + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: msg}}, + IsError: true, + } +} + +func mustGenerateSchema[T any]() map[string]any { + s, err := schema.GenerateSchema[T]() + if err != nil { + panic(fmt.Sprintf("optimizerdec: failed to generate schema: %v", err)) + } + return s +} diff --git a/pkg/vmcp/session/optimizerdec/decorator_test.go b/pkg/vmcp/session/optimizerdec/decorator_test.go new file mode 100644 index 0000000000..4824cd3051 --- /dev/null +++ b/pkg/vmcp/session/optimizerdec/decorator_test.go @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizerdec_test + +import ( + "context" + "errors" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" + sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks" +) + +// stubOptimizer implements optimizer.Optimizer for tests. +type stubOptimizer struct { + findOutput *optimizer.FindToolOutput + findErr error + callOutput *mcp.CallToolResult + callErr error +} + +func (s *stubOptimizer) FindTool(_ context.Context, _ optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + return s.findOutput, s.findErr +} + +func (s *stubOptimizer) CallTool(_ context.Context, _ optimizer.CallToolInput) (*mcp.CallToolResult, error) { + return s.callOutput, s.callErr +} + +func TestOptimizerDecorator_Tools(t *testing.T) { + t.Parallel() + + t.Run("returns only find_tool and call_tool", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return([]vmcp.Tool{{Name: "backend_search"}}).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + got := dec.Tools() + require.Len(t, got, 2) + assert.Equal(t, "find_tool", got[0].Name) + assert.Equal(t, "call_tool", got[1].Name) + // Both tools must have non-empty input schemas. + assert.NotEmpty(t, got[0].InputSchema) + assert.NotEmpty(t, got[1].InputSchema) + }) +} + +func TestOptimizerDecorator_CallTool_FindTool(t *testing.T) { + t.Parallel() + + t.Run("find_tool calls optimizer and returns JSON result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + findOutput := &optimizer.FindToolOutput{ + Tools: []mcp.Tool{{Name: "search"}}, + } + opt := &stubOptimizer{findOutput: findOutput} + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{"tool_description": "web search"} + result, err := dec.CallTool(context.Background(), nil, "find_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + // Result should be non-error and contain the marshaled output. + assert.False(t, result.IsError) + // The structured content should be present or content should have JSON text. + require.NotEmpty(t, result.Content) + }) + + t.Run("find_tool propagates optimizer error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{findErr: errors.New("index unavailable")} + dec := optimizerdec.NewDecorator(base, opt) + + result, err := dec.CallTool(context.Background(), nil, "find_tool", map[string]any{"tool_description": "x"}, nil) + + require.NoError(t, err) // errors are surfaced as IsError results per MCP convention + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("find_tool returns error result when optimizer returns nil output", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{findOutput: nil, findErr: nil} + dec := optimizerdec.NewDecorator(base, opt) + + result, err := dec.CallTool(context.Background(), nil, "find_tool", map[string]any{"tool_description": "x"}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) +} + +func TestOptimizerDecorator_CallTool_CallTool(t *testing.T) { + t.Parallel() + + t.Run("call_tool routes through optimizer and converts result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + // The underlying session must NOT be called — call_tool routes through optimizer.CallTool. + + opt := &stubOptimizer{ + callOutput: mcp.NewToolResultText("fetched content"), + } + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{ + "tool_name": "backend_fetch", + "parameters": map[string]any{"url": "https://example.com"}, + } + result, err := dec.CallTool(context.Background(), &auth.Identity{}, "call_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + require.Len(t, result.Content, 1) + assert.Equal(t, "fetched content", result.Content[0].Text) + }) + + t.Run("call_tool propagates optimizer error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{callErr: errors.New("backend unreachable")} + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{"tool_name": "backend_fetch"} + result, err := dec.CallTool(context.Background(), nil, "call_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("call_tool returns error result when tool_name missing", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + result, err := dec.CallTool(context.Background(), nil, "call_tool", map[string]any{}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("unknown tool returns error", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + _, err := dec.CallTool(context.Background(), nil, "nonexistent_tool", nil, nil) + + require.Error(t, err) + }) +} diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go index 485cb660e2..5bd875cdd9 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go @@ -20,10 +20,15 @@ import ( "github.com/stacklok/toolhive/test/e2e/images" ) -// This test exercises composite tool execution through the optimizer's call_tool. -// Without the fix in injectOptimizerCapabilities, composite tools are registered -// with backend routing handlers (ToSDKTools) instead of workflow execution handlers -// (ToCompositeToolSDKTools), causing call_tool to fail with ErrToolNotFound. +// This test exercises composite tool discovery and execution through the +// optimizer's find_tool / call_tool interface. +// +// Composite tools are registered as session decorators (compositetools.Decorator) +// before the optimizer decorator is applied, so the optimizer indexes them +// alongside backend tools. Workflow steps reference backend tools via the +// "{workloadID}.{originalCapabilityName}" dot convention, which the session +// router resolves to the correct conflict-resolved routing table entry +// regardless of which conflict resolution strategy is in use. // // A lightweight fake embedding server replaces the heavyweight TEI image to keep // test setup fast while satisfying the optimizer's embedding service requirement. @@ -65,6 +70,14 @@ var _ = Describe("VirtualMCPServer Optimizer Composite Tools", Ordered, func() { "url": "{{.params.url}}", } + // Workflow steps use the "{workloadID}.{originalCapabilityName}" dot + // convention so that the session router can resolve them regardless of + // conflict resolution strategy. backendFetchToolName ("fetch") is the + // name the gofetch backend exposes; the aggregation override renames it + // to vmcpFetchToolName for clients, but the step references the + // original backend capability name via the dot convention. + fetchStepTool := backendName + "." + backendFetchToolName // "backend-opt-composite.fetch" + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ ObjectMeta: metav1.ObjectMeta{ Name: vmcpServerName, @@ -103,13 +116,13 @@ var _ = Describe("VirtualMCPServer Optimizer Composite Tools", Ordered, func() { { ID: "first_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, Arguments: thvjson.NewMap(stepArgs), }, { ID: "second_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, DependsOn: []string{"first_fetch"}, Arguments: thvjson.NewMap(stepArgs), }, diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go index 2b35e51859..26469b1e7e 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go @@ -69,6 +69,13 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { "url": "{{.params.url}}", } + // Workflow steps use the "{workloadID}.{originalCapabilityName}" dot + // convention so the session router resolves them regardless of conflict + // resolution strategy. backendFetchToolName ("fetch") is the original + // backend capability; the aggregation override renames it to + // vmcpFetchToolName for clients, but steps must reference the original. + fetchStepTool := backendName + "." + backendFetchToolName // "backend-optimizer-fetch.fetch" + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ ObjectMeta: metav1.ObjectMeta{ Name: vmcpServerName, @@ -110,13 +117,13 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { { ID: "first_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, Arguments: thvjson.NewMap(stepArgs), }, { ID: "second_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, DependsOn: []string{"first_fetch"}, Arguments: thvjson.NewMap(stepArgs), },