diff --git a/pkg/vmcp/composer/elicitation_integration_test.go b/pkg/vmcp/composer/elicitation_integration_test.go index bff4d70fc6..d4edd40c05 100644 --- a/pkg/vmcp/composer/elicitation_integration_test.go +++ b/pkg/vmcp/composer/elicitation_integration_test.go @@ -34,7 +34,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Accept(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "deployment-workflow", @@ -151,7 +151,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Decline(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -228,7 +228,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Cancel(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -275,7 +275,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_Timeout(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -309,7 +309,7 @@ func TestWorkflowEngine_ExecuteElicitationStep_NoHandler(t *testing.T) { te := newTestEngine(t) // Create engine WITHOUT elicitation handler - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, nil, nil) workflow := &WorkflowDefinition{ Name: "test-workflow", @@ -348,7 +348,7 @@ func TestWorkflowEngine_MultiStepWithElicitation(t *testing.T) { handler := NewDefaultElicitationHandler(mockSDK) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil) + engine := NewWorkflowEngine(te.Router, te.Backend, handler, stateStore, nil, nil) workflow := &WorkflowDefinition{ Name: "multi-step-workflow", diff --git a/pkg/vmcp/composer/testhelpers_test.go b/pkg/vmcp/composer/testhelpers_test.go index 29b8cbac78..be1cdb89ba 100644 --- a/pkg/vmcp/composer/testhelpers_test.go +++ b/pkg/vmcp/composer/testhelpers_test.go @@ -30,8 +30,14 @@ func newTestEngine(t *testing.T) *testEngine { t.Cleanup(ctrl.Finish) mockRouter := routermocks.NewMockRouter(ctrl) + // ResolveToolName is called by getToolInputSchema on every tool step. + // For tests that use NewWorkflowEngine (no tools list), the result is + // always nil, so a pass-through AnyTimes expectation is sufficient. + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) - engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil) // nil elicitationHandler, stateStore, and auditor for simple tests + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, nil) // nil elicitationHandler, stateStore, auditor, and tools for simple tests return &testEngine{ Engine: engine, diff --git a/pkg/vmcp/composer/workflow_audit_integration_test.go b/pkg/vmcp/composer/workflow_audit_integration_test.go index ae7c23ccbd..5e5e4d7ded 100644 --- a/pkg/vmcp/composer/workflow_audit_integration_test.go +++ b/pkg/vmcp/composer/workflow_audit_integration_test.go @@ -39,7 +39,7 @@ func TestWorkflowEngine_WithAuditor_SuccessfulWorkflow(t *testing.T) { require.NoError(t, err) // Create engine with auditor - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) // Setup simple workflow workflow := simpleWorkflow("audit-test", @@ -86,7 +86,7 @@ func TestWorkflowEngine_WithAuditor_FailedWorkflow(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := simpleWorkflow("fail-test", toolStep("step1", "tool1", map[string]any{"arg": "value"}), @@ -119,7 +119,7 @@ func TestWorkflowEngine_WithAuditor_WorkflowTimeout(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "timeout-test", @@ -162,7 +162,7 @@ func TestWorkflowEngine_WithAuditor_StepSkipped(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "skip-test", @@ -215,7 +215,7 @@ func TestWorkflowEngine_WithAuditor_RetryStep(t *testing.T) { }) require.NoError(t, err) - engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor) + engine := NewWorkflowEngine(te.Router, te.Backend, nil, nil, auditor, nil) workflow := &WorkflowDefinition{ Name: "retry-test", diff --git a/pkg/vmcp/composer/workflow_engine.go b/pkg/vmcp/composer/workflow_engine.go index c1d5b755c6..ecba1b7b10 100644 --- a/pkg/vmcp/composer/workflow_engine.go +++ b/pkg/vmcp/composer/workflow_engine.go @@ -17,7 +17,6 @@ import ( "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/conversion" - "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/schema" ) @@ -46,6 +45,10 @@ type workflowEngine struct { // backendClient makes calls to backend MCP servers. backendClient vmcp.BackendClient + // tools is the resolved tool list for the session, used by getToolInputSchema + // for argument type coercion. Nil means no schema-based coercion (discovery-based routing). + tools []vmcp.Tool + // templateExpander handles template expansion. templateExpander TemplateExpander @@ -67,12 +70,12 @@ type workflowEngine struct { // NewWorkflowEngine creates a new workflow execution engine. // -// The elicitationHandler parameter is optional. If nil, elicitation steps will fail. -// This allows the engine to be used without elicitation support for simple workflows. +// tools is the resolved tool list for schema-based argument type coercion. Pass nil +// when the engine is used for validation or discovery-based routing only. // +// The elicitationHandler parameter is optional. If nil, elicitation steps will fail. // The stateStore parameter is optional. If nil, workflow status tracking and cancellation // will not be available. Use NewInMemoryStateStore() for basic state tracking. -// // The auditor parameter is optional. If nil, workflow execution will not be audited. func NewWorkflowEngine( rtr router.Router, @@ -80,6 +83,7 @@ func NewWorkflowEngine( elicitationHandler ElicitationProtocolHandler, stateStore WorkflowStateStore, auditor *audit.WorkflowAuditor, + tools []vmcp.Tool, ) Composer { return &workflowEngine{ router: rtr, @@ -90,6 +94,7 @@ func NewWorkflowEngine( dagExecutor: newDAGExecutor(defaultMaxParallelSteps), stateStore: stateStore, auditor: auditor, + tools: tools, } } @@ -1223,20 +1228,19 @@ func (e *workflowEngine) auditStepSkipped( } } -// getToolInputSchema looks up a tool's InputSchema from discovered capabilities. -// Returns nil if the tool is not found or capabilities are not in context. -func (*workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any { - caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) - if !ok || caps == nil { - return nil - } - - // Search in backend tools - for i := range caps.Tools { - if caps.Tools[i].Name == toolName { - return caps.Tools[i].InputSchema +// getToolInputSchema looks up a tool's InputSchema from the session-bound tools +// list. If toolName uses the dot convention "{workloadID}.{originalCapabilityName}", +// ResolveToolName is called to translate it to the conflict-resolved key before +// lookup. Returns nil if the engine has no tools list or the tool is not found. +func (e *workflowEngine) getToolInputSchema(ctx context.Context, toolName string) map[string]any { + resolved := toolName + if e.router != nil { + resolved = e.router.ResolveToolName(ctx, toolName) + } + for i := range e.tools { + if e.tools[i].Name == resolved { + return e.tools[i].InputSchema } } - return nil } diff --git a/pkg/vmcp/composer/workflow_engine_test.go b/pkg/vmcp/composer/workflow_engine_test.go index 8c4039b202..323937a3d7 100644 --- a/pkg/vmcp/composer/workflow_engine_test.go +++ b/pkg/vmcp/composer/workflow_engine_test.go @@ -390,9 +390,12 @@ func TestWorkflowEngine_ParallelExecution(t *testing.T) { defer ctrl.Finish() mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() mockBackend := mocks.NewMockBackendClient(ctrl) stateStore := NewInMemoryStateStore(1*time.Minute, 1*time.Hour) - engine := NewWorkflowEngine(mockRouter, mockBackend, nil, stateStore, nil) + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, stateStore, nil, nil) // Track execution timing to verify parallel execution var executionMu sync.Mutex @@ -741,3 +744,104 @@ func TestWorkflowEngine_WorkflowMetadataAvailableInTemplates(t *testing.T) { assert.Equal(t, WorkflowStatusCompleted, result.Status) assert.Len(t, result.Steps, 2) } + +func TestWorkflowEngine_SessionEngine_CoercesTemplateStringToTypedArg(t *testing.T) { + t.Parallel() + + // Template expansion always produces strings. When the engine is created + // with a bound tool list, getToolInputSchema resolves the target tool's InputSchema + // and the schema coercion layer converts "42" → 42 before calling the backend. + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() + mockBackend := mocks.NewMockBackendClient(ctrl) + + tools := []vmcp.Tool{ + { + Name: "count_items", + InputSchema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "limit": map[string]any{"type": "integer"}, + }, + }, + }, + } + + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + + target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} + mockRouter.EXPECT().RouteTool(gomock.Any(), "count_items").Return(target, nil) + + // Expect the backend to receive the coerced integer, not the string "42". + coercedArgs := map[string]any{"limit": int64(42)} + mockBackend.EXPECT(). + CallTool(gomock.Any(), target, "count_items", coercedArgs, gomock.Any()). + Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"items": []any{}}, Content: []vmcp.Content{}}, nil) + + workflow := &WorkflowDefinition{ + Name: "coerce_test", + Parameters: map[string]any{ + "type": "object", + "properties": map[string]any{ + "n": map[string]any{"type": "string"}, + }, + }, + Steps: []WorkflowStep{ + { + ID: "step1", + Type: StepTypeTool, + Tool: "count_items", + // Template expansion produces a string; coercion must convert it to int. + Arguments: map[string]any{"limit": "{{.params.n}}"}, + }, + }, + } + + result, err := engine.ExecuteWorkflow(context.Background(), workflow, map[string]any{"n": "42"}) + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) +} + +func TestWorkflowEngine_SessionEngine_ToolNotInList_ReturnsNilSchema(t *testing.T) { + t.Parallel() + + // When a bound tool list is provided but the requested tool is not in it, + // getToolInputSchema returns nil and coercion is a no-op. + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRouter := routermocks.NewMockRouter(ctrl) + mockRouter.EXPECT().ResolveToolName(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, name string) string { return name }). + AnyTimes() + mockBackend := mocks.NewMockBackendClient(ctrl) + + // Tools list does not include "other_tool". + tools := []vmcp.Tool{{Name: "known_tool", InputSchema: map[string]any{"type": "object"}}} + engine := NewWorkflowEngine(mockRouter, mockBackend, nil, nil, nil, tools) + + target := &vmcp.BackendTarget{WorkloadID: "backend1", BaseURL: "http://backend1:8080"} + mockRouter.EXPECT().RouteTool(gomock.Any(), "other_tool").Return(target, nil) + + // Args pass through unmodified (string stays a string). + rawArgs := map[string]any{"value": "hello"} + mockBackend.EXPECT(). + CallTool(gomock.Any(), target, "other_tool", rawArgs, gomock.Any()). + Return(&vmcp.ToolCallResult{StructuredContent: map[string]any{"ok": true}, Content: []vmcp.Content{}}, nil) + + workflow := &WorkflowDefinition{ + Name: "no_schema_test", + Steps: []WorkflowStep{ + {ID: "s1", Type: StepTypeTool, Tool: "other_tool", Arguments: rawArgs}, + }, + } + + result, err := engine.ExecuteWorkflow(context.Background(), workflow, nil) + require.NoError(t, err) + assert.Equal(t, WorkflowStatusCompleted, result.Status) +} diff --git a/pkg/vmcp/discovery/middleware.go b/pkg/vmcp/discovery/middleware.go index 6b322307bf..323c7a8bc2 100644 --- a/pkg/vmcp/discovery/middleware.go +++ b/pkg/vmcp/discovery/middleware.go @@ -281,10 +281,12 @@ func handleSubsequentRequest( return ctx, fmt.Errorf("session not found: %s", sessionID) } - // Backend tool calls are routed by session-scoped handlers registered with the SDK. - // However, composite tool workflow steps go through the shared router which requires - // DiscoveredCapabilities in the context. Inject capabilities built from the session's - // routing table so composite workflows can route backend tool calls correctly. + // Backend tool handlers (created by DefaultHandlerFactory) resolve their backend + // target by calling router.RouteTool(ctx, name), which reads DiscoveredCapabilities + // from the request context. Inject capabilities built from the session's routing + // table so these handlers can route correctly on subsequent requests. + // Note: composite tool workflow engines are created per-session and route via + // SessionRouter directly, so they no longer depend on this context value. multiSess, isMulti := rawSess.(vmcpsession.MultiSession) if !isMulti { // The session is still a StreamableSession placeholder — Phase 2 diff --git a/pkg/vmcp/router/default_router.go b/pkg/vmcp/router/default_router.go index 199e94243d..3e6f3bfba7 100644 --- a/pkg/vmcp/router/default_router.go +++ b/pkg/vmcp/router/default_router.go @@ -89,6 +89,13 @@ func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.Bac ) } +// ResolveToolName returns toolName unchanged. The defaultRouter has no static +// routing table, so dot-convention resolution is not available; the caller +// should already be using resolved names when working with this router. +func (*defaultRouter) ResolveToolName(_ context.Context, toolName string) string { + return toolName +} + // RouteResource resolves a resource URI to its backend target. // With lazy discovery, this method gets capabilities from the request context // instead of using a cached routing table. diff --git a/pkg/vmcp/router/mocks/mock_router.go b/pkg/vmcp/router/mocks/mock_router.go index 6db4f729b5..768522c506 100644 --- a/pkg/vmcp/router/mocks/mock_router.go +++ b/pkg/vmcp/router/mocks/mock_router.go @@ -41,6 +41,20 @@ func (m *MockRouter) EXPECT() *MockRouterMockRecorder { return m.recorder } +// ResolveToolName mocks base method. +func (m *MockRouter) ResolveToolName(ctx context.Context, toolName string) string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ResolveToolName", ctx, toolName) + ret0, _ := ret[0].(string) + return ret0 +} + +// ResolveToolName indicates an expected call of ResolveToolName. +func (mr *MockRouterMockRecorder) ResolveToolName(ctx, toolName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveToolName", reflect.TypeOf((*MockRouter)(nil).ResolveToolName), ctx, toolName) +} + // RoutePrompt mocks base method. func (m *MockRouter) RoutePrompt(ctx context.Context, name string) (*vmcp.BackendTarget, error) { m.ctrl.T.Helper() diff --git a/pkg/vmcp/router/router.go b/pkg/vmcp/router/router.go index 9b4cb52f59..63d7fc299b 100644 --- a/pkg/vmcp/router/router.go +++ b/pkg/vmcp/router/router.go @@ -28,6 +28,14 @@ type Router interface { // Returns ErrToolNotFound if the tool doesn't exist in any backend. RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error) + // ResolveToolName translates a tool name (which may use the dot-convention + // "{workloadID}.{originalCapabilityName}") to the conflict-resolved routing + // table key used in the session tools list. Returns toolName unchanged when + // the name cannot be resolved or the router has no static routing table — + // pass-through semantics so callers can use the result directly without + // special-casing the unresolvable case. + ResolveToolName(ctx context.Context, toolName string) string + // RouteResource resolves a resource URI to its backend target. // Returns ErrResourceNotFound if the resource doesn't exist in any backend. RouteResource(ctx context.Context, uri string) (*vmcp.BackendTarget, error) diff --git a/pkg/vmcp/router/session_router.go b/pkg/vmcp/router/session_router.go new file mode 100644 index 0000000000..436cc18c70 --- /dev/null +++ b/pkg/vmcp/router/session_router.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package router + +import ( + "context" + "fmt" + "strings" + + "github.com/stacklok/toolhive/pkg/vmcp" +) + +// sessionRouter is a Router implementation backed directly by a RoutingTable, +// requiring no request context to resolve capabilities. It is used by +// per-session workflow engines so that composite tool execution does not depend +// on the discovery middleware injecting DiscoveredCapabilities into the context. +type sessionRouter struct { + routingTable *vmcp.RoutingTable +} + +// NewSessionRouter creates a Router that routes from the provided RoutingTable +// without reading the request context. This is the preferred router for +// composite tool workflow engines because it couples routing to the session +// rather than to middleware-managed context values. +func NewSessionRouter(rt *vmcp.RoutingTable) Router { + return &sessionRouter{routingTable: rt} +} + +// RouteTool resolves a tool name to its backend target using the session's +// routing table directly. +// +// Two naming conventions are supported: +// +// 1. Exact key: the resolved/conflict-resolved name stored in the routing +// table (e.g. "my-backend_echo" after prefix conflict resolution). +// +// 2. Dot convention "{workloadID}.{toolName}": the tool name is the original +// backend capability name and the workload ID is the prefix. This mirrors +// the isToolStepAccessible logic used when registering composite tools and +// lets workflow step definitions remain stable regardless of the conflict +// resolution strategy in use. +// +// The dot convention is necessary because composite workflow steps reference +// tools by their pre-conflict-resolution name (e.g. "my-backend.echo"), while +// the routing table may store them under a prefixed key ("my-backend_echo"). +func (r *sessionRouter) RouteTool(_ context.Context, toolName string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Tools == nil { + return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) + } + + // Fast path: exact key match. + if target, exists := r.routingTable.Tools[toolName]; exists { + return target, nil + } + + // Fallback: dot convention "{workloadID}.{toolName}". + // Workload IDs are Kubernetes resource names and cannot contain dots, + // so the first dot unambiguously separates the workload ID from the + // original backend capability name. + if dotIdx := strings.Index(toolName, "."); dotIdx > 0 { + workloadID := toolName[:dotIdx] + capName := toolName[dotIdx+1:] + for resolvedName, target := range r.routingTable.Tools { + if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName { + return target, nil + } + } + } + + return nil, fmt.Errorf("%w: %s", ErrToolNotFound, toolName) +} + +// ResolveToolName returns the routing table key (conflict-resolved name) for +// toolName. If toolName is an exact key it is returned unchanged. If it uses +// the dot convention "{workloadID}.{originalCapabilityName}", the matching +// routing table key is returned. Falls back to returning toolName unchanged +// when the routing table is absent or the name cannot be resolved (pass-through +// semantics, consistent with the Router interface contract). +func (r *sessionRouter) ResolveToolName(_ context.Context, toolName string) string { + if r.routingTable == nil || r.routingTable.Tools == nil { + return toolName + } + + // Fast path: exact key match. + if _, exists := r.routingTable.Tools[toolName]; exists { + return toolName + } + + // Fallback: dot convention "{workloadID}.{toolName}". + if dotIdx := strings.Index(toolName, "."); dotIdx > 0 { + workloadID := toolName[:dotIdx] + capName := toolName[dotIdx+1:] + for resolvedName, target := range r.routingTable.Tools { + if target.WorkloadID == workloadID && target.GetBackendCapabilityName(resolvedName) == capName { + return resolvedName + } + } + } + + return toolName +} + +// RouteResource resolves a resource URI to its backend target using the +// session's routing table directly. +func (r *sessionRouter) RouteResource(_ context.Context, uri string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Resources == nil { + return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri) + } + target, exists := r.routingTable.Resources[uri] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrResourceNotFound, uri) + } + return target, nil +} + +// RoutePrompt resolves a prompt name to its backend target using the session's +// routing table directly. +func (r *sessionRouter) RoutePrompt(_ context.Context, name string) (*vmcp.BackendTarget, error) { + if r.routingTable == nil || r.routingTable.Prompts == nil { + return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name) + } + target, exists := r.routingTable.Prompts[name] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrPromptNotFound, name) + } + return target, nil +} diff --git a/pkg/vmcp/router/session_router_test.go b/pkg/vmcp/router/session_router_test.go new file mode 100644 index 0000000000..0b1d520891 --- /dev/null +++ b/pkg/vmcp/router/session_router_test.go @@ -0,0 +1,383 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package router_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +func TestSessionRouter_RouteTool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + toolName string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing tool", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend1", + WorkloadName: "Backend 1", + BaseURL: "http://backend1:8080", + }, + }, + }, + toolName: "test_tool", + expectedID: "backend1", + expectError: false, + }, + { + name: "tool not found", + routingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + }, + toolName: "nonexistent_tool", + expectError: true, + errorContains: "tool not found", + }, + { + name: "nil routing table", + routingTable: nil, + toolName: "test_tool", + expectError: true, + errorContains: "tool not found", + }, + { + name: "nil tools map", + routingTable: &vmcp.RoutingTable{ + Tools: nil, + }, + toolName: "test_tool", + expectError: true, + errorContains: "tool not found", + }, + { + // Composite workflow steps use "{workloadID}.{toolName}" where toolName + // is the original backend capability name. With prefix conflict resolution + // the routing table key is "{workloadID}_toolName", so an exact match + // fails. The dot-convention fallback must resolve it correctly. + name: "dot convention resolved via workload ID and original capability name", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": { + WorkloadID: "my-backend", + WorkloadName: "My Backend", + BaseURL: "http://my-backend:8080", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.echo", + expectedID: "my-backend", + expectError: false, + }, + { + name: "dot convention: workload not in session", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "other-backend_echo": { + WorkloadID: "other-backend", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.echo", + expectError: true, + errorContains: "tool not found", + }, + { + name: "dot convention: capability name mismatch", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": { + WorkloadID: "my-backend", + OriginalCapabilityName: "echo", + }, + }, + }, + toolName: "my-backend.fetch", + expectError: true, + errorContains: "tool not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RouteTool(context.Background(), tt.toolName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_ResolveToolName(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + toolName string + expectedName string + }{ + { + name: "exact key returned unchanged", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"}, + }, + }, + toolName: "my-backend_echo", + expectedName: "my-backend_echo", + }, + { + name: "dot convention resolves to routing table key", + routingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "my-backend_echo": {WorkloadID: "my-backend", OriginalCapabilityName: "echo"}, + }, + }, + toolName: "my-backend.echo", + expectedName: "my-backend_echo", + }, + { + name: "not found returns toolName unchanged (pass-through)", + routingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + }, + toolName: "missing_tool", + expectedName: "missing_tool", + }, + { + name: "nil routing table returns toolName unchanged (pass-through)", + routingTable: nil, + toolName: "any_tool", + expectedName: "any_tool", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + resolved := r.ResolveToolName(context.Background(), tt.toolName) + + assert.Equal(t, tt.expectedName, resolved) + }) + } +} + +func TestSessionRouter_RouteResource(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + uri string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing resource", + routingTable: &vmcp.RoutingTable{ + Resources: map[string]*vmcp.BackendTarget{ + "file:///path/to/resource": { + WorkloadID: "backend2", + WorkloadName: "Backend 2", + BaseURL: "http://backend2:8080", + }, + }, + }, + uri: "file:///path/to/resource", + expectedID: "backend2", + expectError: false, + }, + { + name: "resource not found", + routingTable: &vmcp.RoutingTable{ + Resources: make(map[string]*vmcp.BackendTarget), + }, + uri: "file:///nonexistent", + expectError: true, + errorContains: "resource not found", + }, + { + name: "nil routing table", + routingTable: nil, + uri: "file:///test", + expectError: true, + errorContains: "resource not found", + }, + { + name: "nil resources map", + routingTable: &vmcp.RoutingTable{ + Resources: nil, + }, + uri: "file:///test", + expectError: true, + errorContains: "resource not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RouteResource(context.Background(), tt.uri) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_RoutePrompt(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + routingTable *vmcp.RoutingTable + promptName string + expectedID string + expectError bool + errorContains string + }{ + { + name: "route to existing prompt", + routingTable: &vmcp.RoutingTable{ + Prompts: map[string]*vmcp.BackendTarget{ + "greeting": { + WorkloadID: "backend3", + WorkloadName: "Backend 3", + BaseURL: "http://backend3:8080", + }, + }, + }, + promptName: "greeting", + expectedID: "backend3", + expectError: false, + }, + { + name: "prompt not found", + routingTable: &vmcp.RoutingTable{ + Prompts: make(map[string]*vmcp.BackendTarget), + }, + promptName: "nonexistent", + expectError: true, + errorContains: "prompt not found", + }, + { + name: "nil routing table", + routingTable: nil, + promptName: "test", + expectError: true, + errorContains: "prompt not found", + }, + { + name: "nil prompts map", + routingTable: &vmcp.RoutingTable{ + Prompts: nil, + }, + promptName: "test", + expectError: true, + errorContains: "prompt not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + r := router.NewSessionRouter(tt.routingTable) + target, err := r.RoutePrompt(context.Background(), tt.promptName) + + if tt.expectError { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errorContains) + assert.Nil(t, target) + } else { + require.NoError(t, err) + require.NotNil(t, target) + assert.Equal(t, tt.expectedID, target.WorkloadID) + } + }) + } +} + +func TestSessionRouter_ConcurrentAccess(t *testing.T) { + t.Parallel() + + table := &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "tool1": {WorkloadID: "backend1"}, + "tool2": {WorkloadID: "backend2"}, + }, + Resources: map[string]*vmcp.BackendTarget{ + "res1": {WorkloadID: "backend1"}, + }, + Prompts: map[string]*vmcp.BackendTarget{ + "prompt1": {WorkloadID: "backend2"}, + }, + } + + r := router.NewSessionRouter(table) + ctx := context.Background() + + const numGoroutines = 10 + const numOperations = 100 + + done := make(chan bool, numGoroutines) + + for i := 0; i < numGoroutines; i++ { + go func() { + for j := 0; j < numOperations; j++ { + _, _ = r.RouteTool(ctx, "tool1") + _, _ = r.RouteResource(ctx, "res1") + _, _ = r.RoutePrompt(ctx, "prompt1") + } + done <- true + }() + } + + for i := 0; i < numGoroutines; i++ { + <-done + } + + target, err := r.RouteTool(ctx, "tool1") + require.NoError(t, err) + assert.Equal(t, "backend1", target.WorkloadID) +} diff --git a/pkg/vmcp/server/adapter/handler_factory.go b/pkg/vmcp/server/adapter/handler_factory.go index 73c5ddc12c..c14d73e15b 100644 --- a/pkg/vmcp/server/adapter/handler_factory.go +++ b/pkg/vmcp/server/adapter/handler_factory.go @@ -19,6 +19,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/conversion" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/router" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" ) //go:generate mockgen -destination=mocks/mock_handler_factory.go -package=mocks github.com/stacklok/toolhive/pkg/vmcp/server/adapter HandlerFactory @@ -43,20 +44,13 @@ type HandlerFactory interface { } // WorkflowExecutor executes composite tool workflows. -// This interface abstracts the composer to enable testing without full composer setup. -type WorkflowExecutor interface { - // ExecuteWorkflow executes the workflow with the given parameters. - ExecuteWorkflow(ctx context.Context, params map[string]any) (*WorkflowResult, error) -} +// Type alias for compositetools.WorkflowExecutor so that adapter consumers and +// the session decorator share a single interface definition. +type WorkflowExecutor = compositetools.WorkflowExecutor // WorkflowResult represents the result of a workflow execution. -type WorkflowResult struct { - // Output contains the workflow output data (typically from the last step). - Output map[string]any - - // Error contains error information if the workflow failed. - Error error -} +// Type alias for compositetools.WorkflowResult. +type WorkflowResult = compositetools.WorkflowResult // DefaultHandlerFactory creates MCP request handlers that route to backend workloads. type DefaultHandlerFactory struct { diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go index 717d4745e4..b962a0de88 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -3,124 +3,10 @@ package adapter -import ( - "context" - "encoding/json" - "fmt" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" - "github.com/stacklok/toolhive/pkg/vmcp/schema" -) - -// OptimizerToolNames defines the tool names exposed when optimizer is enabled. +// OptimizerToolNames defines the tool names exposed when optimizer mode is enabled. +// These constants are kept here for backwards compatibility with existing tests and +// callers. The actual tool implementation lives in the optimizerdec session decorator. const ( FindToolName = "find_tool" CallToolName = "call_tool" ) - -// Pre-generated schemas for optimizer tools. -// Generated at package init time so any schema errors panic at startup. -var ( - findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() - callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() -) - -// CreateOptimizerTools creates the SDK tools for optimizer mode. -// When optimizer is enabled, only these two tools are exposed to clients -// instead of all backend tools. -func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { - return []server.ServerTool{ - { - Tool: mcp.Tool{ - Name: FindToolName, - Description: "Find and return tools that can help accomplish the user's request. " + - "This searches available MCP server tools using semantic and keyword-based matching. " + - "Use this function when you need to: " + - "(1) discover what tools are available for a specific task, " + - "(2) find the right tool(s) before attempting to solve a problem, " + - "(3) check if required functionality exists in the current environment. " + - "Returns matching tools ranked by relevance including their names, descriptions, " + - "required parameters and schemas, plus token efficiency metrics showing " + - "baseline_tokens, returned_tokens, and savings_percent. " + - "Example: for 'Find good restaurants in San Jose', call with " + - "tool_description='search the web' and tool_keywords='web search restaurants'. " + - "Always call this before call_tool to discover the correct tool name and parameter schema.", - RawInputSchema: findToolInputSchema, - }, - Handler: createFindToolHandler(opt), - }, - { - Tool: mcp.Tool{ - Name: CallToolName, - Description: "Execute a specific tool with the provided parameters. " + - "Use this function to: " + - "(1) run a tool after identifying it with find_tool, " + - "(2) execute operations that require specific MCP server functionality, " + - "(3) perform actions that go beyond your built-in capabilities. " + - "Important: always use find_tool first to get the correct tool_name " + - "and parameter schema before calling this function. " + - "The parameters must match the tool's input schema as returned by find_tool. " + - "Returns the tool's execution result which may include success/failure status, " + - "result data or content, and error messages if execution failed.", - RawInputSchema: callToolInputSchema, - }, - Handler: createCallToolHandler(opt), - }, - } -} - -// createFindToolHandler creates a handler for the find_tool optimizer operation. -func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - output, err := opt.FindTool(ctx, input) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil - } - - return mcp.NewToolResultStructuredOnly(output), nil - } -} - -// createCallToolHandler creates a handler for the call_tool optimizer operation. -func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - result, err := opt.CallTool(ctx, input) - if err != nil { - // Exposing the error to the MCP client is important if you want it to correct its behavior. - // Without information on the failure, the model is pretty much hopeless in figuring out the problem. - return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil - } - - return result, nil - } -} - -// mustMarshalSchema marshals a schema to JSON, panicking on error. -// This is safe because schemas are generated from known types at startup. -// This should NOT be called by runtime code. -func mustGenerateSchema[T any]() json.RawMessage { - s, err := schema.GenerateSchema[T]() - if err != nil { - panic(fmt.Sprintf("failed to generate schema: %v", err)) - } - - data, err := json.Marshal(s) - if err != nil { - panic(fmt.Sprintf("failed to marshal schema: %v", err)) - } - - return data -} diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go index f6228746c7..53fba4b1a9 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -4,103 +4,14 @@ package adapter import ( - "context" "testing" - "github.com/mark3labs/mcp-go/mcp" - "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stretchr/testify/assert" ) -// mockOptimizer implements optimizer.Optimizer for testing. -type mockOptimizer struct { - findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) - callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) -} - -func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - if m.findToolFunc != nil { - return m.findToolFunc(ctx, input) - } - return &optimizer.FindToolOutput{}, nil -} - -func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - if m.callToolFunc != nil { - return m.callToolFunc(ctx, input) - } - return mcp.NewToolResultText("ok"), nil -} - -func TestCreateOptimizerTools(t *testing.T) { - t.Parallel() - - opt := &mockOptimizer{} - tools := CreateOptimizerTools(opt) - - require.Len(t, tools, 2) - require.Equal(t, FindToolName, tools[0].Tool.Name) - require.Equal(t, CallToolName, tools[1].Tool.Name) -} - -func TestFindToolHandler(t *testing.T) { +func TestOptimizerToolNameConstants(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - require.Equal(t, "read files", input.ToolDescription) - return &optimizer.FindToolOutput{ - Tools: []mcp.Tool{ - { - Name: "read_file", - Description: "Read a file", - }, - }, - }, nil - }, - } - - tools := CreateOptimizerTools(opt) - handler := tools[0].Handler - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_description": "read files", - } - - result, err := handler(context.Background(), request) - require.NoError(t, err) - require.NotNil(t, result) - require.False(t, result.IsError) - require.Len(t, result.Content, 1) -} - -func TestCallToolHandler(t *testing.T) { - t.Parallel() - - opt := &mockOptimizer{ - callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - require.Equal(t, "read_file", input.ToolName) - require.Equal(t, "/etc/hosts", input.Parameters["path"]) - return mcp.NewToolResultText("file contents here"), nil - }, - } - - tools := CreateOptimizerTools(opt) - handler := tools[1].Handler - - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_name": "read_file", - "parameters": map[string]any{ - "path": "/etc/hosts", - }, - } - - result, err := handler(context.Background(), request) - require.NoError(t, err) - require.NotNil(t, result) - require.False(t, result.IsError) - require.Len(t, result.Content, 1) + assert.Equal(t, "find_tool", FindToolName) + assert.Equal(t, "call_tool", CallToolName) } diff --git a/pkg/vmcp/server/composite_tool_converter.go b/pkg/vmcp/server/composite_tool_converter.go index 48a2b70448..953b0bb6cf 100644 --- a/pkg/vmcp/server/composite_tool_converter.go +++ b/pkg/vmcp/server/composite_tool_converter.go @@ -7,12 +7,95 @@ package server import ( "fmt" + "strings" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/config" ) +// filterWorkflowDefsForSession returns only the workflow definitions whose every +// tool step references a backend tool that is present in the session routing table. +// +// If a session does not have access to a backend tool (e.g. due to identity-based +// filtering), any composite tool that depends on that backend tool is also excluded. +// This prevents a session from invoking a composite tool that would fail at runtime +// because one or more of its underlying tools are not routable for that session. +func filterWorkflowDefsForSession( + defs map[string]*composer.WorkflowDefinition, + rt *vmcp.RoutingTable, +) map[string]*composer.WorkflowDefinition { + if len(defs) == 0 { + return defs + } + + filtered := make(map[string]*composer.WorkflowDefinition, len(defs)) + for name, def := range defs { + if allToolStepsAccessible(def, rt) { + filtered[name] = def + } + } + return filtered +} + +// allToolStepsAccessible reports whether every tool step in the workflow +// references a backend tool that is present in the session routing table. +// Returns false if rt is nil and the workflow contains any tool steps, +// since a nil routing table means no tools are routable in this session. +// +// Composite tool step names use the convention "{workloadID}.{toolName}" where +// workloadID is a Kubernetes resource name (no dots). The routing table may store +// tools under resolved/prefixed names (e.g. "{workloadID}_echo" with prefix strategy), +// so we look up by BackendTarget.WorkloadID rather than the resolved key directly. +func allToolStepsAccessible(def *composer.WorkflowDefinition, rt *vmcp.RoutingTable) bool { + for _, step := range def.Steps { + if step.Type == composer.StepTypeTool { + if rt == nil { + return false + } + if !isToolStepAccessible(step.Tool, rt) { + return false + } + } + } + return true +} + +// isToolStepAccessible reports whether a composite tool step's tool name can be +// resolved to an accessible backend tool in the given routing table. +// +// Step tool names use the "{workloadID}.{toolName}" convention. Since conflict +// resolution strategies (e.g. prefix) may rename tools in the routing table +// (e.g. "echo" → "yardstick-backend_echo"), we check for accessibility by +// matching on WorkloadID and the original backend capability name rather than +// the resolved routing table key. +func isToolStepAccessible(stepTool string, rt *vmcp.RoutingTable) bool { + // Fast path: exact match in the routing table. + if _, ok := rt.Tools[stepTool]; ok { + return true + } + + // Parse "{workloadID}.{toolName}" convention. + // Workload IDs are Kubernetes resource names and cannot contain dots, + // so the first dot separates the workload ID from the tool name. + dotIdx := strings.Index(stepTool, ".") + if dotIdx <= 0 { + return false + } + workloadID := stepTool[:dotIdx] + originalName := stepTool[dotIdx+1:] + + for resolvedName, target := range rt.Tools { + if target.WorkloadID != workloadID { + continue + } + if target.GetBackendCapabilityName(resolvedName) == originalName { + return true + } + } + return false +} + // convertWorkflowDefsToTools converts workflow definitions to vmcp.Tool format. // // This creates the tool metadata (name, description, schema) that gets exposed diff --git a/pkg/vmcp/server/composite_tool_converter_test.go b/pkg/vmcp/server/composite_tool_converter_test.go index d8c32cecc9..6c05200a85 100644 --- a/pkg/vmcp/server/composite_tool_converter_test.go +++ b/pkg/vmcp/server/composite_tool_converter_test.go @@ -8,6 +8,7 @@ import ( "github.com/google/go-cmp/cmp" + "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/config" ) @@ -375,6 +376,166 @@ func TestConvertWorkflowDefsToToolsWithOutputSchema(t *testing.T) { } } +func TestFilterWorkflowDefsForSession(t *testing.T) { + t.Parallel() + + makeRT := func(toolNames ...string) *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + for _, name := range toolNames { + rt.Tools[name] = &vmcp.BackendTarget{WorkloadID: name} + } + return rt + } + + tests := []struct { + name string + defs map[string]*composer.WorkflowDefinition + rt *vmcp.RoutingTable + wantNames []string // workflow names expected in result + }{ + { + name: "empty defs", + defs: map[string]*composer.WorkflowDefinition{}, + rt: makeRT("tool_a"), + wantNames: []string{}, + }, + { + name: "all tools accessible", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + }, + rt: makeRT("tool_a", "tool_b"), + wantNames: []string{"wf1"}, + }, + { + name: "missing tool excludes workflow", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + }, + rt: makeRT("tool_b"), + wantNames: []string{}, + }, + { + name: "partially accessible: only accessible workflow included", + defs: map[string]*composer.WorkflowDefinition{ + "wf_ok": { + Name: "wf_ok", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}, + }, + }, + "wf_restricted": { + Name: "wf_restricted", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "tool_secret"}, + }, + }, + }, + rt: makeRT("tool_a"), + wantNames: []string{"wf_ok"}, + }, + { + name: "elicitation steps do not require routing table entry", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeElicitation}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "tool_a"}, + }, + }, + }, + rt: makeRT("tool_a"), + wantNames: []string{"wf1"}, + }, + { + // Composite tool steps use "{workloadID}.{toolName}" convention. + // With prefix conflict resolution the routing table key is + // "{workloadID}_echo", but the step still uses "{workloadID}.echo". + // The filter must resolve via WorkloadID + OriginalCapabilityName. + name: "dotted step tool resolved via workload ID and original name", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "my-backend.echo"}, + }, + }, + }, + rt: func() *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + // Prefix strategy stores "my-backend_echo" as the resolved key. + rt.Tools["my-backend_echo"] = &vmcp.BackendTarget{ + WorkloadID: "my-backend", + OriginalCapabilityName: "echo", + } + return rt + }(), + wantNames: []string{"wf1"}, + }, + { + name: "dotted step tool excluded when workload not in session", + defs: map[string]*composer.WorkflowDefinition{ + "wf1": { + Name: "wf1", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "restricted-backend.echo"}, + }, + }, + }, + rt: func() *vmcp.RoutingTable { + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget)} + rt.Tools["other-backend_echo"] = &vmcp.BackendTarget{ + WorkloadID: "other-backend", + OriginalCapabilityName: "echo", + } + return rt + }(), + wantNames: []string{}, + }, + { + name: "nil routing table excludes workflows with tool steps", + defs: map[string]*composer.WorkflowDefinition{ + "wf_tool": { + Name: "wf_tool", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeTool, Tool: "tool_a"}}, + }, + "wf_elicit_only": { + Name: "wf_elicit_only", + Steps: []composer.WorkflowStep{{ID: "s1", Type: composer.StepTypeElicitation}}, + }, + }, + rt: nil, + wantNames: []string{"wf_elicit_only"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := filterWorkflowDefsForSession(tt.defs, tt.rt) + + if len(got) != len(tt.wantNames) { + t.Errorf("filterWorkflowDefsForSession() returned %d defs, want %d (%v)", + len(got), len(tt.wantNames), tt.wantNames) + } + for _, name := range tt.wantNames { + if _, ok := got[name]; !ok { + t.Errorf("expected workflow %q in result but it was absent", name) + } + } + }) + } +} + func TestBuildOutputPropertySchema(t *testing.T) { t.Parallel() diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 2d267ce4a3..23b1d7523d 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -37,6 +37,9 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" "github.com/stacklok/toolhive/pkg/vmcp/server/sessionmanager" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" vmcpstatus "github.com/stacklok/toolhive/pkg/vmcp/status" ) @@ -213,11 +216,17 @@ type Server struct { // Thread-safety: Safe for concurrent reads (no writes after initialization). workflowDefs map[string]*composer.WorkflowDefinition - // Workflow executors for composite tools (adapters around composer + definition). - // Used by capability adapter to create composite tool handlers. - // Initialized during construction and read-only thereafter. - // Thread-safety: Safe for concurrent reads (no writes after initialization). - workflowExecutors map[string]adapter.WorkflowExecutor + // composerFactory creates a per-session workflow Composer bound to the + // session's routing table and tool list. Called at session registration + // time so that composite tool execution routes via the session rather than + // depending on the discovery middleware injecting DiscoveredCapabilities + // into the request context. + composerFactory func(rt *vmcp.RoutingTable, tools []vmcp.Tool) composer.Composer + + // workflowInstruments holds pre-created OTEL metric instruments for workflow + // telemetry. Nil when telemetry is disabled. Created once at server startup + // and shared across all per-session executor wrappers. + workflowInstruments *workflowExecutorInstruments // Ready channel signals when the server is ready to accept connections. // Closed once the listener is created and serving. @@ -334,25 +343,36 @@ func New( // The composer orchestrates multi-step workflows across backends // Use in-memory state store with 5-minute cleanup interval and 1-hour max age for completed workflows stateStore := composer.NewInMemoryStateStore(5*time.Minute, 1*time.Hour) - workflowComposer := composer.NewWorkflowEngine(rt, backendClient, elicitationHandler, stateStore, workflowAuditor) + workflowComposer := composer.NewWorkflowEngine(rt, backendClient, elicitationHandler, stateStore, workflowAuditor, nil) + + // composerFactory builds a per-session workflow engine at session registration + // time, binding composite tool routing to the session's own routing table and + // tool list. This removes composite tools' dependency on the discovery middleware + // injecting DiscoveredCapabilities into the request context. + sessionComposerFactory := func(sessionRT *vmcp.RoutingTable, sessionTools []vmcp.Tool) composer.Composer { + return composer.NewWorkflowEngine( + router.NewSessionRouter(sessionRT), backendClient, elicitationHandler, stateStore, workflowAuditor, + sessionTools, + ) + } - // Validate workflows and create executors (fail fast on invalid workflows) - var workflowExecutors map[string]adapter.WorkflowExecutor + // Validate workflows (fail fast on invalid definitions) var err error - workflowDefs, workflowExecutors, err = validateAndCreateExecutors(workflowComposer, workflowDefs) + workflowDefs, err = validateWorkflows(workflowComposer, workflowDefs) if err != nil { return nil, fmt.Errorf("workflow validation failed: %w", err) } - // Decorate workflow executors with telemetry if provider is configured - if cfg.TelemetryProvider != nil && len(workflowExecutors) > 0 { - workflowExecutors, err = monitorWorkflowExecutors( + // Pre-create workflow telemetry instruments once so they can be reused + // across all per-session executor wrappers without re-registering metrics. + var workflowInstruments *workflowExecutorInstruments + if cfg.TelemetryProvider != nil && len(workflowDefs) > 0 { + workflowInstruments, err = newWorkflowExecutorInstruments( cfg.TelemetryProvider.MeterProvider(), cfg.TelemetryProvider.TracerProvider(), - workflowExecutors, ) if err != nil { - return nil, fmt.Errorf("failed to monitor workflow executors: %w", err) + return nil, fmt.Errorf("failed to create workflow executor telemetry: %w", err) } } @@ -396,21 +416,22 @@ func New( // Create Server instance srv := &Server{ - config: cfg, - mcpServer: mcpServer, - router: rt, - backendClient: backendClient, - handlerFactory: handlerFactory, - discoveryMgr: discoveryMgr, - backendRegistry: backendRegistry, - sessionManager: sessionManager, - capabilityAdapter: capabilityAdapter, - workflowDefs: workflowDefs, - workflowExecutors: workflowExecutors, - ready: make(chan struct{}), - healthMonitor: healthMon, - statusReporter: cfg.StatusReporter, - vmcpSessionMgr: vmcpSessMgr, + config: cfg, + mcpServer: mcpServer, + router: rt, + backendClient: backendClient, + handlerFactory: handlerFactory, + discoveryMgr: discoveryMgr, + backendRegistry: backendRegistry, + sessionManager: sessionManager, + capabilityAdapter: capabilityAdapter, + workflowDefs: workflowDefs, + composerFactory: sessionComposerFactory, + workflowInstruments: workflowInstruments, + ready: make(chan struct{}), + healthMonitor: healthMon, + statusReporter: cfg.StatusReporter, + vmcpSessionMgr: vmcpSessMgr, } // Register OnRegisterSession hook to inject capabilities after SDK registers session. @@ -993,6 +1014,26 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv return retErr } + // Apply composite tools decorator (if workflow defs are configured). + // This appends composite tools to the session's tool list and routes their + // CallTool invocations through per-session workflow executors. + if len(s.workflowDefs) > 0 { + if retErr = s.applyCompositeToolsDecorator(sessionID); retErr != nil { + return retErr + } + } + + // Apply optimizer decorator (if configured). + // Must come after composite tools so the optimizer indexes composite tools too. + // Replaces the full tool list with find_tool + call_tool. + if s.config.OptimizerFactory != nil { + if retErr = s.applyOptimizerDecorator(ctx, sessionID); retErr != nil { + return retErr + } + } + + // Uniform registration — same code path regardless of which decorators are active. + // session.Tools() returns the final decorated tool list. adaptedTools, retErr := s.vmcpSessionMgr.GetAdaptedTools(sessionID) if retErr != nil { slog.Error("failed to get session-scoped tools", @@ -1009,12 +1050,6 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv return retErr } - // Collect composite SDK tools (with name-collision check against backend tools). - compositeSDKTools, retErr := s.collectCompositeTools(sessionID) - if retErr != nil { - return retErr - } - if len(adaptedResources) > 0 { if err := setSessionResourcesDirect(session, adaptedResources); err != nil { slog.Error("failed to add session resources", "session_id", sessionID, "error", err) @@ -1022,113 +1057,116 @@ func (s *Server) handleSessionRegistrationImpl(ctx context.Context, session serv } } - return s.injectTools(ctx, session, adaptedTools, compositeSDKTools) -} - -// collectCompositeTools converts workflow definitions to SDK tools, -// validating that no composite tool name collides with a backend tool name. -// Returns an empty slice (not an error) if no workflow defs are configured or conflicts are found. -func (s *Server) collectCompositeTools(sessionID string) ([]server.ServerTool, error) { - if len(s.workflowDefs) == 0 { - return nil, nil + if len(adaptedTools) > 0 { + if err := setSessionToolsDirect(session, adaptedTools); err != nil { + slog.Error("failed to add session tools", "session_id", sessionID, "error", err) + return err + } } - compositeTools := convertWorkflowDefsToTools(s.workflowDefs) + slog.Info("session capabilities injected", + "session_id", sessionID, + "tool_count", len(adaptedTools)) + return nil +} + +// applyCompositeToolsDecorator wraps the session with a compositeToolsDecorator. +// It filters workflow definitions for the session, validates name conflicts with +// backend tools, builds per-session workflow executors, then calls DecorateSession. +// +// Non-fatal conditions (no accessible defs, name conflicts) log a warning and +// leave the session undecorated rather than failing. +func (s *Server) applyCompositeToolsDecorator(sessionID string) error { multiSess, hasSess := s.vmcpSessionMgr.GetMultiSession(sessionID) if !hasSess { - slog.Error("session not found after creation; skipping composite tools", + slog.Warn("session not found after creation; skipping composite tools", "session_id", sessionID) - return nil, nil + return nil } - if err := validateNoToolConflicts(multiSess.Tools(), compositeTools); err != nil { - slog.Error("composite tool name conflict detected; skipping composite tools", + + sessionDefs := filterWorkflowDefsForSession(s.workflowDefs, multiSess.GetRoutingTable()) + if len(sessionDefs) == 0 { + return nil + } + + compositeToolsMeta := convertWorkflowDefsToTools(sessionDefs) + if err := validateNoToolConflicts(multiSess.Tools(), compositeToolsMeta); err != nil { + slog.Warn("composite tool name conflict detected; skipping composite tools", "session_id", sessionID, "error", err) - return nil, nil + return nil } - sdkTools, err := s.capabilityAdapter.ToCompositeToolSDKTools(compositeTools, s.workflowExecutors) - if err != nil { - return nil, fmt.Errorf("failed to convert composite tools: %w", err) + // Build per-session workflow executors bound to this session's routing table. + sessionComposer := s.composerFactory(multiSess.GetRoutingTable(), multiSess.Tools()) + sessionExecutors := make(map[string]compositetools.WorkflowExecutor, len(sessionDefs)) + for _, def := range sessionDefs { + ex := newComposerWorkflowExecutor(sessionComposer, def) + if s.workflowInstruments != nil { + ex = s.workflowInstruments.wrapExecutor(def.Name, ex) + } + sessionExecutors[def.Name] = ex } - return sdkTools, nil -} -// injectTools registers backend and composite tools into the session. -// When the optimizer is configured, all tools are indexed and only -// find_tool/call_tool are exposed; otherwise tools are registered directly. -func (s *Server) injectTools( - ctx context.Context, - session server.ClientSession, - adaptedTools []server.ServerTool, - compositeSDKTools []server.ServerTool, -) error { - sessionID := session.SessionID() + return s.vmcpSessionMgr.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return compositetools.NewDecorator(sess, compositeToolsMeta, sessionExecutors) + }) +} - if s.config.OptimizerFactory != nil { - allTools := append(adaptedTools, compositeSDKTools...) - opt, err := s.config.OptimizerFactory(ctx, allTools) - if err != nil { - return fmt.Errorf("failed to create optimizer: %w", err) - } - if err = setSessionToolsDirect(session, adapter.CreateOptimizerTools(opt)); err != nil { - slog.Error("failed to add optimizer tools to session", "session_id", sessionID, "error", err) - return err - } - slog.Info("session capabilities injected (optimizer mode)", - "session_id", sessionID, - "indexed_tool_count", len(allTools)) - return nil +// applyOptimizerDecorator wraps the session with an optimizerDecorator. +// It reads the current session's tool list (including composite tools if applied), +// builds SDK tools for the optimizer factory to index, creates the optimizer, and +// calls DecorateSession. The optimizer replaces the tool list with find_tool + call_tool. +func (s *Server) applyOptimizerDecorator(ctx context.Context, sessionID string) error { + // Snapshot the pre-optimizer SDK tools for the optimizer to index. + // This includes composite tools if the composite decorator was already applied. + sdkTools, err := s.vmcpSessionMgr.GetAdaptedTools(sessionID) + if err != nil { + return fmt.Errorf("failed to get tools for optimizer: %w", err) } - if len(adaptedTools) > 0 { - if err := setSessionToolsDirect(session, adaptedTools); err != nil { - slog.Error("failed to add session tools", "session_id", sessionID, "error", err) - return err - } + opt, err := s.config.OptimizerFactory(ctx, sdkTools) + if err != nil { + return fmt.Errorf("failed to create optimizer: %w", err) } - if len(compositeSDKTools) > 0 { - if err := setSessionToolsDirect(session, compositeSDKTools); err != nil { - slog.Error("failed to add composite tools to session", "session_id", sessionID, "error", err) - return err - } - slog.Debug("added composite tools to session", "session_id", sessionID, "count", len(compositeSDKTools)) + + if err = s.vmcpSessionMgr.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return optimizerdec.NewDecorator(sess, opt) + }); err != nil { + return err } - slog.Info("session capabilities injected", + slog.Info("session capabilities injected (optimizer mode)", "session_id", sessionID, - "tool_count", len(adaptedTools), - "composite_tool_count", len(compositeSDKTools)) + "indexed_tool_count", len(sdkTools)) + return nil } -// validateAndCreateExecutors validates workflow definitions and creates executors. +// validateWorkflows validates workflow definitions, returning only the valid ones. // // This function: // 1. Validates each workflow definition (cycle detection, tool references, etc.) // 2. Returns error on first validation failure (fail-fast) -// 3. Creates workflow executors for all valid workflows // // Failing fast on invalid workflows provides immediate user feedback and prevents // security issues (resource exhaustion from cycles, information disclosure from errors). -func validateAndCreateExecutors( +func validateWorkflows( validator composer.Composer, workflowDefs map[string]*composer.WorkflowDefinition, -) (map[string]*composer.WorkflowDefinition, map[string]adapter.WorkflowExecutor, error) { +) (map[string]*composer.WorkflowDefinition, error) { if len(workflowDefs) == 0 { - return nil, nil, nil + return nil, nil } validDefs := make(map[string]*composer.WorkflowDefinition, len(workflowDefs)) - validExecutors := make(map[string]adapter.WorkflowExecutor, len(workflowDefs)) for name, def := range workflowDefs { if err := validator.ValidateWorkflow(context.Background(), def); err != nil { - return nil, nil, fmt.Errorf("invalid workflow definition '%s': %w", name, err) + return nil, fmt.Errorf("invalid workflow definition '%s': %w", name, err) } validDefs[name] = def - validExecutors[name] = newComposerWorkflowExecutor(validator, def) slog.Debug("validated workflow definition", "name", name) } @@ -1136,7 +1174,7 @@ func validateAndCreateExecutors( slog.Info("loaded valid composite tool workflows", "count", len(validDefs)) } - return validDefs, validExecutors, nil + return validDefs, nil } // GetBackendHealthStatus returns the health status of a specific backend. diff --git a/pkg/vmcp/server/session_management_integration_test.go b/pkg/vmcp/server/session_management_integration_test.go index 127f30b295..7ddfc3785e 100644 --- a/pkg/vmcp/server/session_management_integration_test.go +++ b/pkg/vmcp/server/session_management_integration_test.go @@ -115,7 +115,13 @@ func newMockFactory(t *testing.T, ctrl *gomock.Controller, tools []vmcp.Tool) (* mock.EXPECT().Resources().Return(nil).AnyTimes() mock.EXPECT().Prompts().Return(nil).AnyTimes() mock.EXPECT().BackendSessions().Return(nil).AnyTimes() - mock.EXPECT().GetRoutingTable().Return(nil).AnyTimes() + // Build a routing table from the provided tools so that + // filterWorkflowDefsForSession can check tool accessibility per session. + rt := &vmcp.RoutingTable{Tools: make(map[string]*vmcp.BackendTarget, len(tools))} + for _, tool := range tools { + rt.Tools[tool.Name] = &vmcp.BackendTarget{WorkloadID: tool.Name} + } + mock.EXPECT().GetRoutingTable().Return(rt).AnyTimes() mock.EXPECT().ReadResource(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mock.EXPECT().GetPrompt(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() callResult := &vmcp.ToolCallResult{Content: []vmcp.Content{{Type: "text", Text: "fake result"}}} @@ -770,6 +776,79 @@ func TestIntegration_SessionManagement_CompositeToolConflict(t *testing.T) { "backend tool call should succeed after conflict detection; body: %s", string(respBody)) } +// TestIntegration_SessionManagement_CompositeToolsFilteredForSession verifies that +// composite tools whose underlying backend tools are not routable in a session are +// excluded from that session's tools/list. This enforces per-session authorization: +// a session that cannot access a backend tool also cannot access composite tools +// that depend on it. +func TestIntegration_SessionManagement_CompositeToolsFilteredForSession(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + // The session only has "allowed-tool"; it does NOT have "restricted-tool". + allowedTool := vmcp.Tool{Name: "allowed-tool", Description: "accessible backend tool"} + factory, _ := newMockFactory(t, ctrl, []vmcp.Tool{allowedTool}) + + // accessible-workflow only uses allowed-tool → should appear for this session. + accessibleDef := &composer.WorkflowDefinition{ + Name: "accessible-workflow", + Description: "uses only allowed backend tools", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "allowed-tool"}, + }, + } + // restricted-workflow uses restricted-tool which is absent from this session's + // routing table → must NOT appear for this session. + restrictedDef := &composer.WorkflowDefinition{ + Name: "restricted-workflow", + Description: "uses a backend tool not accessible in this session", + Steps: []composer.WorkflowStep{ + {ID: "s1", Type: composer.StepTypeTool, Tool: "allowed-tool"}, + {ID: "s2", Type: composer.StepTypeTool, Tool: "restricted-tool"}, + }, + } + + ts := buildTestServerWithOptions(t, factory, serverOptions{ + workflowDefs: map[string]*composer.WorkflowDefinition{ + "accessible-workflow": accessibleDef, + "restricted-workflow": restrictedDef, + }, + }) + + initResp := postMCP(t, ts.URL, map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": map[string]any{ + "protocolVersion": "2025-06-18", + "capabilities": map[string]any{}, + "clientInfo": map[string]any{"name": "test", "version": "1.0"}, + }, + }, "") + defer initResp.Body.Close() + require.Equal(t, http.StatusOK, initResp.StatusCode) + + sessionID := initResp.Header.Get("Mcp-Session-Id") + require.NotEmpty(t, sessionID) + + // Wait until accessible-workflow appears, then verify restricted-workflow does not. + require.Eventually(t, func() bool { + for _, n := range listToolNames(t, ts.URL, sessionID) { + if n == "accessible-workflow" { + return true + } + } + return false + }, 2*time.Second, 20*time.Millisecond, + "accessible-workflow should appear in tools/list") + + toolNames := listToolNames(t, ts.URL, sessionID) + assert.Contains(t, toolNames, "accessible-workflow", + "composite tool whose backend tools are all accessible must be visible") + assert.NotContains(t, toolNames, "restricted-workflow", + "composite tool that depends on an inaccessible backend tool must be hidden") +} + // TestIntegration_SessionManagement_OptimizerMode verifies that when an optimizer // factory is configured with session management, tools/list exposes only // find_tool and call_tool (the optimizer wraps all backend tools). diff --git a/pkg/vmcp/server/session_manager_interface.go b/pkg/vmcp/server/session_manager_interface.go index 359cab43f1..770cb7fff9 100644 --- a/pkg/vmcp/server/session_manager_interface.go +++ b/pkg/vmcp/server/session_manager_interface.go @@ -9,6 +9,7 @@ import ( mcpserver "github.com/mark3labs/mcp-go/server" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" ) // SessionManager extends the SDK's SessionIdManager with Phase 2 session creation @@ -39,4 +40,12 @@ type SessionManager interface { // Returns (nil, false) if the session does not exist or is still a placeholder. // Used to access session-scoped backend tool metadata (e.g. for conflict validation). GetMultiSession(sessionID string) (vmcpsession.MultiSession, bool) + + // DecorateSession retrieves the MultiSession for sessionID, applies fn to it, + // and stores the result back. Used to stack session decorators (composite tools, + // optimizer) after the base session is created. + DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error + + // Terminate terminates the session with the given ID, closing all backend connections. + Terminate(sessionID string) (bool, error) } diff --git a/pkg/vmcp/server/sessionmanager/session_manager.go b/pkg/vmcp/server/sessionmanager/session_manager.go index 71532abc0d..13bd0dd822 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager.go +++ b/pkg/vmcp/server/sessionmanager/session_manager.go @@ -322,6 +322,41 @@ func (sm *Manager) GetMultiSession(sessionID string) (vmcpsession.MultiSession, return multiSess, ok } +// DecorateSession retrieves the MultiSession for sessionID, applies fn to it, +// and stores the result back. Returns an error if the session is not found or +// has not yet been upgraded from placeholder to MultiSession. +// +// A re-check is performed immediately before UpsertSession to guard against a +// race with Terminate(): if the session is deleted between GetMultiSession and +// UpsertSession, the upsert would silently resurrect a terminated session. The +// re-check catches that window. A narrow TOCTOU gap remains between the +// re-check and the upsert, but its consequence is bounded: Terminate() already +// called Close() on the underlying MultiSession before deleting it, so any +// resurrected decorator wraps an already-closed session and will fail on first +// use rather than leaking backend connections. +func (sm *Manager) DecorateSession(sessionID string, fn func(sessiontypes.MultiSession) sessiontypes.MultiSession) error { + sess, ok := sm.GetMultiSession(sessionID) + if !ok { + return fmt.Errorf("DecorateSession: session %q not found or not a multi-session", sessionID) + } + decorated := fn(sess) + if decorated == nil { + return fmt.Errorf("DecorateSession: decorator returned nil session") + } + if decorated.ID() != sessionID { + return fmt.Errorf("DecorateSession: decorator changed session ID from %q to %q", sessionID, decorated.ID()) + } + // Re-check: guard against a race with Terminate() deleting the session + // between GetMultiSession (above) and UpsertSession (below). + if _, ok := sm.GetMultiSession(sessionID); !ok { + return fmt.Errorf("DecorateSession: session %q was terminated during decoration", sessionID) + } + if err := sm.storage.UpsertSession(decorated); err != nil { + return fmt.Errorf("DecorateSession: failed to store decorated session: %w", err) + } + return nil +} + // GetAdaptedTools returns SDK-format tools for the given session, with handlers // that delegate tool invocations directly to the session's CallTool() method. // diff --git a/pkg/vmcp/server/sessionmanager/session_manager_test.go b/pkg/vmcp/server/sessionmanager/session_manager_test.go index 481a542721..058dfd37be 100644 --- a/pkg/vmcp/server/sessionmanager/session_manager_test.go +++ b/pkg/vmcp/server/sessionmanager/session_manager_test.go @@ -1378,6 +1378,113 @@ func TestSessionManager_GetAdaptedResources(t *testing.T) { }) } +// --------------------------------------------------------------------------- +// Tests: DecorateSession +// --------------------------------------------------------------------------- + +func TestSessionManager_DecorateSession(t *testing.T) { + t.Parallel() + + t.Run("replaces session with decorated result", func(t *testing.T) { + t.Parallel() + + tools := []vmcp.Tool{{Name: "hello", Description: "says hello"}} + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT(). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + return newMockSession(t, ctrl, id, tools), nil + }).Times(1) + + registry := newFakeRegistry() + sm, _ := newTestSessionManager(t, factory, registry) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + _, err := sm.CreateSession(context.Background(), sessionID) + require.NoError(t, err) + + // Apply a decorator that wraps with an extra tool. + extraTool := vmcp.Tool{Name: "extra", Description: "extra tool"} + err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + decorated := sessionmocks.NewMockMultiSession(ctrl) + // Delegate everything to base session + decorated.EXPECT().ID().Return(sess.ID()).AnyTimes() + decorated.EXPECT().Tools().Return(append(sess.Tools(), extraTool)).AnyTimes() + // other methods delegated via AnyTimes + decorated.EXPECT().Type().Return(sess.Type()).AnyTimes() + decorated.EXPECT().CreatedAt().Return(sess.CreatedAt()).AnyTimes() + decorated.EXPECT().UpdatedAt().Return(sess.UpdatedAt()).AnyTimes() + decorated.EXPECT().Touch().AnyTimes() + decorated.EXPECT().GetData().Return(nil).AnyTimes() + decorated.EXPECT().SetData(gomock.Any()).AnyTimes() + decorated.EXPECT().GetMetadata().Return(map[string]string{}).AnyTimes() + decorated.EXPECT().SetMetadata(gomock.Any(), gomock.Any()).AnyTimes() + decorated.EXPECT().BackendSessions().Return(nil).AnyTimes() + decorated.EXPECT().GetRoutingTable().Return(nil).AnyTimes() + decorated.EXPECT().Prompts().Return(nil).AnyTimes() + return decorated + }) + require.NoError(t, err) + + // After decoration, GetMultiSession returns the decorated session with both tools. + multiSess, ok := sm.GetMultiSession(sessionID) + require.True(t, ok) + require.Len(t, multiSess.Tools(), 2) + assert.Equal(t, "hello", multiSess.Tools()[0].Name) + assert.Equal(t, "extra", multiSess.Tools()[1].Name) + }) + + t.Run("returns error for unknown session", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + sm, _ := newTestSessionManager(t, newMockFactory(t, ctrl, newMockSession(t, ctrl, "", nil)), newFakeRegistry()) + + err := sm.DecorateSession("ghost-session", func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + return sess + }) + require.Error(t, err) + }) + + t.Run("returns error if session terminated during decoration", func(t *testing.T) { + t.Parallel() + + // Simulate the race: Terminate() is called between GetMultiSession and + // UpsertSession. We do this by terminating the session inside the + // decorator fn, so the re-check that follows fn() sees it is gone. + ctrl := gomock.NewController(t) + factory := sessionfactorymocks.NewMockMultiSessionFactory(ctrl) + factory.EXPECT(). + MakeSessionWithID(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, id string, _ *auth.Identity, _ bool, _ []*vmcp.Backend) (vmcpsession.MultiSession, error) { + sess := newMockSession(t, ctrl, id, nil) + sess.EXPECT().Close().Return(nil).AnyTimes() + return sess, nil + }).Times(1) + + sm, _ := newTestSessionManager(t, factory, newFakeRegistry()) + + sessionID := sm.Generate() + require.NotEmpty(t, sessionID) + _, err := sm.CreateSession(context.Background(), sessionID) + require.NoError(t, err) + + err = sm.DecorateSession(sessionID, func(sess sessiontypes.MultiSession) sessiontypes.MultiSession { + // Simulate concurrent Terminate() completing during decoration. + _, _ = sm.Terminate(sessionID) + return sess + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "terminated during decoration") + + // The session must not be resurrected. + _, ok := sm.GetMultiSession(sessionID) + assert.False(t, ok, "terminated session must not be resurrected by DecorateSession") + }) +} + // --------------------------------------------------------------------------- // Helper // --------------------------------------------------------------------------- diff --git a/pkg/vmcp/server/telemetry.go b/pkg/vmcp/server/telemetry.go index dc000aa184..ba275fe86f 100644 --- a/pkg/vmcp/server/telemetry.go +++ b/pkg/vmcp/server/telemetry.go @@ -258,17 +258,23 @@ func (t telemetryBackendClient) ListCapabilities( return t.backendClient.ListCapabilities(ctx, target) } -// monitorWorkflowExecutors decorates workflow executors with telemetry recording. -// It wraps each executor to emit metrics and traces for execution count, duration, and errors. -func monitorWorkflowExecutors( +// workflowExecutorInstruments holds pre-created OTEL instruments for workflow telemetry. +// Instruments are created once at server startup and reused across all session registrations +// to avoid re-registering the same metric names on every session creation. +type workflowExecutorInstruments struct { + tracer trace.Tracer + executionsTotal metric.Int64Counter + errorsTotal metric.Int64Counter + executionDuration metric.Float64Histogram +} + +// newWorkflowExecutorInstruments creates the OTEL instruments used to decorate +// per-session workflow executors. Call this once at server startup; pass the +// result to wrapExecutor at session registration time. +func newWorkflowExecutorInstruments( meterProvider metric.MeterProvider, tracerProvider trace.TracerProvider, - executors map[string]adapter.WorkflowExecutor, -) (map[string]adapter.WorkflowExecutor, error) { - if len(executors) == 0 { - return executors, nil - } - +) (*workflowExecutorInstruments, error) { meter := meterProvider.Meter(instrumentationName) executionsTotal, err := meter.Int64Counter( @@ -297,21 +303,25 @@ func monitorWorkflowExecutors( return nil, fmt.Errorf("failed to create workflow duration histogram: %w", err) } - tracer := tracerProvider.Tracer(instrumentationName) + return &workflowExecutorInstruments{ + tracer: tracerProvider.Tracer(instrumentationName), + executionsTotal: executionsTotal, + errorsTotal: errorsTotal, + executionDuration: executionDuration, + }, nil +} - monitored := make(map[string]adapter.WorkflowExecutor, len(executors)) - for name, executor := range executors { - monitored[name] = &telemetryWorkflowExecutor{ - name: name, - executor: executor, - tracer: tracer, - executionsTotal: executionsTotal, - errorsTotal: errorsTotal, - executionDuration: executionDuration, - } +// wrapExecutor returns a telemetry-decorated WorkflowExecutor using the +// pre-created instruments. Safe to call on every session registration. +func (i *workflowExecutorInstruments) wrapExecutor(name string, ex adapter.WorkflowExecutor) adapter.WorkflowExecutor { + return &telemetryWorkflowExecutor{ + name: name, + executor: ex, + tracer: i.tracer, + executionsTotal: i.executionsTotal, + errorsTotal: i.errorsTotal, + executionDuration: i.executionDuration, } - - return monitored, nil } // telemetryWorkflowExecutor wraps a WorkflowExecutor with telemetry recording. diff --git a/pkg/vmcp/session/compositetools/decorator.go b/pkg/vmcp/session/compositetools/decorator.go new file mode 100644 index 0000000000..35f672bfd5 --- /dev/null +++ b/pkg/vmcp/session/compositetools/decorator.go @@ -0,0 +1,113 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package compositetools provides a MultiSession decorator that adds composite +// tool (workflow) capabilities to a session. +package compositetools + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +// WorkflowExecutor executes a named composite tool workflow. +type WorkflowExecutor interface { + ExecuteWorkflow(ctx context.Context, params map[string]any) (*WorkflowResult, error) +} + +// WorkflowResult holds the output of a workflow execution. +type WorkflowResult struct { + Output map[string]any + Error error +} + +// compositeToolsDecorator wraps a MultiSession to add composite tool routing. +// It overrides Tools() to append composite tool metadata and CallTool() to +// intercept composite tool names and dispatch them to workflow executors. +// All other MultiSession methods delegate to the embedded session. +type compositeToolsDecorator struct { + sessiontypes.MultiSession + compositeTools []vmcp.Tool + executors map[string]WorkflowExecutor +} + +func errorResult(msg string) *vmcp.ToolCallResult { + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: msg}}, + IsError: true, + } +} + +// NewDecorator wraps sess with composite tool support. compositeTools is the +// metadata list appended to session.Tools(). executors maps each composite tool +// name to its workflow executor. Both may be nil/empty. +func NewDecorator( + sess sessiontypes.MultiSession, + compositeTools []vmcp.Tool, + executors map[string]WorkflowExecutor, +) sessiontypes.MultiSession { + return &compositeToolsDecorator{ + MultiSession: sess, + compositeTools: compositeTools, + executors: executors, + } +} + +// Tools returns backend tools followed by composite tools. +func (d *compositeToolsDecorator) Tools() []vmcp.Tool { + backend := d.MultiSession.Tools() + if len(d.compositeTools) == 0 { + return backend + } + out := make([]vmcp.Tool, len(backend), len(backend)+len(d.compositeTools)) + copy(out, backend) + return append(out, d.compositeTools...) +} + +// CallTool dispatches composite tool names to their workflow executors. +// Unknown names are delegated to the embedded session. +func (d *compositeToolsDecorator) CallTool( + ctx context.Context, + caller *auth.Identity, + toolName string, + arguments map[string]any, + meta map[string]any, +) (*vmcp.ToolCallResult, error) { + if exec, ok := d.executors[toolName]; ok { + slog.Debug("handling composite tool call", "tool", toolName) + res, err := exec.ExecuteWorkflow(ctx, arguments) + if err != nil { + if errors.Is(err, context.DeadlineExceeded) { + slog.Warn("workflow execution timeout", "tool", toolName, "error", err) + return errorResult("Workflow execution timeout exceeded"), nil + } + slog.Error("workflow execution failed", "tool", toolName, "error", err) + return errorResult(fmt.Sprintf("Workflow execution failed: %v", err)), nil + } + if res == nil { + slog.Error("workflow executor returned nil result", "tool", toolName) + return errorResult("Workflow executor returned nil result"), nil + } + if res.Error != nil { + slog.Error("workflow completed with error", "tool", toolName, "error", res.Error) + return errorResult(fmt.Sprintf("Workflow error: %v", res.Error)), nil + } + slog.Debug("composite tool completed successfully", "tool", toolName) + jsonBytes, err := json.Marshal(res.Output) + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal output: %v", err)), nil + } + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: string(jsonBytes)}}, + StructuredContent: res.Output, + }, nil + } + return d.MultiSession.CallTool(ctx, caller, toolName, arguments, meta) +} diff --git a/pkg/vmcp/session/compositetools/decorator_test.go b/pkg/vmcp/session/compositetools/decorator_test.go new file mode 100644 index 0000000000..ba8bfe1fd3 --- /dev/null +++ b/pkg/vmcp/session/compositetools/decorator_test.go @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package compositetools_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/session/compositetools" + sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks" +) + +// stubExecutor is a simple WorkflowExecutor for tests. +type stubExecutor struct { + output map[string]any + err error +} + +func (s *stubExecutor) ExecuteWorkflow(_ context.Context, _ map[string]any) (*compositetools.WorkflowResult, error) { + return &compositetools.WorkflowResult{Output: s.output}, s.err +} + +func TestCompositeToolsDecorator_Tools(t *testing.T) { + t.Parallel() + + t.Run("appends composite tools to backend tools", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + backendTools := []vmcp.Tool{{Name: "backend_search", Description: "search"}} + base.EXPECT().Tools().Return(backendTools).AnyTimes() + + compositeToolList := []vmcp.Tool{{Name: "my_workflow", Description: "a workflow"}} + dec := compositetools.NewDecorator(base, compositeToolList, nil) + + got := dec.Tools() + require.Len(t, got, 2) + assert.Equal(t, "backend_search", got[0].Name) + assert.Equal(t, "my_workflow", got[1].Name) + }) + + t.Run("returns only backend tools when no composite tools", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + backendTools := []vmcp.Tool{{Name: "backend_search", Description: "search"}} + base.EXPECT().Tools().Return(backendTools).AnyTimes() + + dec := compositetools.NewDecorator(base, nil, nil) + + got := dec.Tools() + require.Len(t, got, 1) + assert.Equal(t, "backend_search", got[0].Name) + }) +} + +func TestCompositeToolsDecorator_CallTool(t *testing.T) { + t.Parallel() + + t.Run("routes composite tool name to executor", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + expectedOutput := map[string]any{"result": "done"} + exec := &stubExecutor{output: expectedOutput} + executors := map[string]compositetools.WorkflowExecutor{"my_workflow": exec} + + dec := compositetools.NewDecorator(base, []vmcp.Tool{{Name: "my_workflow"}}, executors) + result, err := dec.CallTool(context.Background(), nil, "my_workflow", map[string]any{"x": 1}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, expectedOutput, result.StructuredContent) + }) + + t.Run("delegates unknown tool name to embedded session", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + expectedResult := &vmcp.ToolCallResult{IsError: false} + base.EXPECT(). + CallTool(gomock.Any(), gomock.Any(), "backend_tool", gomock.Any(), gomock.Any()). + Return(expectedResult, nil) + + dec := compositetools.NewDecorator(base, nil, nil) + result, err := dec.CallTool(context.Background(), &auth.Identity{}, "backend_tool", nil, nil) + + require.NoError(t, err) + assert.Equal(t, expectedResult, result) + }) + + t.Run("propagates executor error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + execErr := errors.New("workflow failed") + exec := &stubExecutor{err: execErr} + executors := map[string]compositetools.WorkflowExecutor{"failing_wf": exec} + + dec := compositetools.NewDecorator(base, []vmcp.Tool{{Name: "failing_wf"}}, executors) + result, err := dec.CallTool(context.Background(), nil, "failing_wf", nil, nil) + + require.NoError(t, err) // errors surface as IsError results per MCP convention + require.NotNil(t, result) + assert.True(t, result.IsError) + }) +} diff --git a/pkg/vmcp/session/optimizerdec/decorator.go b/pkg/vmcp/session/optimizerdec/decorator.go new file mode 100644 index 0000000000..a59c00bbdd --- /dev/null +++ b/pkg/vmcp/session/optimizerdec/decorator.go @@ -0,0 +1,178 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package optimizerdec provides a MultiSession decorator that replaces the +// full tool list with two optimizer tools: find_tool and call_tool. +package optimizerdec + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/conversion" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/schema" + sessiontypes "github.com/stacklok/toolhive/pkg/vmcp/session/types" +) + +const ( + // FindToolName is the tool name for semantic tool discovery. + FindToolName = "find_tool" + // CallToolName is the tool name for routing a call to any backend tool. + CallToolName = "call_tool" +) + +// Pre-generated schemas for find_tool and call_tool, computed at init time. +var ( + findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() + callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() +) + +// optimizerDecorator wraps a MultiSession to expose only find_tool and call_tool. +// Tools() returns only those two tools. CallTool("find_tool") routes through the +// optimizer's FindTool; CallTool("call_tool") routes through the optimizer's +// CallTool so that all optimizer telemetry (traces, metrics) is recorded. +type optimizerDecorator struct { + sessiontypes.MultiSession + opt optimizer.Optimizer + optimizerTools []vmcp.Tool +} + +// NewDecorator wraps sess with optimizer mode. Only find_tool and call_tool are +// exposed via Tools(). find_tool calls opt.FindTool. call_tool calls opt.CallTool, +// which routes through the instrumented optimizer (telemetry, traces, metrics). +func NewDecorator(sess sessiontypes.MultiSession, opt optimizer.Optimizer) sessiontypes.MultiSession { + return &optimizerDecorator{ + MultiSession: sess, + opt: opt, + optimizerTools: []vmcp.Tool{ + { + Name: FindToolName, + Description: "Find and return tools that can help accomplish the user's request. " + + "This searches available MCP server tools using semantic and keyword-based matching. " + + "Use this function when you need to: " + + "(1) discover what tools are available for a specific task, " + + "(2) find the right tool(s) before attempting to solve a problem, " + + "(3) check if required functionality exists in the current environment. " + + "Returns matching tools ranked by relevance including their names, descriptions, " + + "required parameters and schemas, plus token efficiency metrics showing " + + "baseline_tokens, returned_tokens, and savings_percent. " + + "Always call this before call_tool to discover the correct tool name and parameter schema.", + InputSchema: findToolInputSchema, + }, + { + Name: CallToolName, + Description: "Execute a specific tool with the provided parameters. " + + "Use this function to run a tool after identifying it with find_tool. " + + "Important: always use find_tool first to get the correct tool_name " + + "and parameter schema before calling this function.", + InputSchema: callToolInputSchema, + }, + }, + } +} + +// Tools returns only find_tool and call_tool, replacing the full backend tool list. +// A defensive copy is returned so callers cannot mutate the decorator's internal slice. +func (d *optimizerDecorator) Tools() []vmcp.Tool { + result := make([]vmcp.Tool, len(d.optimizerTools)) + copy(result, d.optimizerTools) + return result +} + +// CallTool handles find_tool and call_tool. Both route through the optimizer so +// that all optimizer telemetry is recorded. Any other tool name returns an error. +func (d *optimizerDecorator) CallTool( + ctx context.Context, + _ *auth.Identity, + toolName string, + arguments map[string]any, + _ map[string]any, +) (*vmcp.ToolCallResult, error) { + switch toolName { + case FindToolName: + return d.handleFindTool(ctx, arguments) + case CallToolName: + return d.handleCallTool(ctx, arguments) + default: + return nil, fmt.Errorf("tool not found: %s", toolName) + } +} + +func (d *optimizerDecorator) handleFindTool(ctx context.Context, arguments map[string]any) (*vmcp.ToolCallResult, error) { + input, err := schema.Translate[optimizer.FindToolInput](arguments) + if err != nil { + return errorResult(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + output, err := d.opt.FindTool(ctx, input) + if err != nil { + return errorResult(fmt.Sprintf("find_tool failed: %v", err)), nil + } + if output == nil { + return errorResult("find_tool: optimizer returned nil result"), nil + } + + jsonBytes, err := json.Marshal(output) + if err != nil { + return errorResult(fmt.Sprintf("failed to marshal find_tool output: %v", err)), nil + } + + var structured map[string]any + _ = json.Unmarshal(jsonBytes, &structured) + + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: string(jsonBytes)}}, + StructuredContent: structured, + }, nil +} + +func (d *optimizerDecorator) handleCallTool( + ctx context.Context, + arguments map[string]any, +) (*vmcp.ToolCallResult, error) { + input, err := schema.Translate[optimizer.CallToolInput](arguments) + if err != nil { + return errorResult(fmt.Sprintf("invalid arguments: %v", err)), nil + } + + mcpResult, err := d.opt.CallTool(ctx, input) + if err != nil { + return errorResult(fmt.Sprintf("call_tool failed: %v", err)), nil + } + if mcpResult == nil { + return errorResult("call_tool: optimizer returned nil result"), nil + } + + return mcpResultToVMCPResult(mcpResult), nil +} + +// mcpResultToVMCPResult converts an MCP SDK CallToolResult to the vmcp domain type. +func mcpResultToVMCPResult(r *mcp.CallToolResult) *vmcp.ToolCallResult { + structured, _ := r.StructuredContent.(map[string]any) + return &vmcp.ToolCallResult{ + Content: conversion.ConvertMCPContents(r.Content), + StructuredContent: structured, + IsError: r.IsError, + } +} + +func errorResult(msg string) *vmcp.ToolCallResult { + return &vmcp.ToolCallResult{ + Content: []vmcp.Content{{Type: "text", Text: msg}}, + IsError: true, + } +} + +func mustGenerateSchema[T any]() map[string]any { + s, err := schema.GenerateSchema[T]() + if err != nil { + panic(fmt.Sprintf("optimizerdec: failed to generate schema: %v", err)) + } + return s +} diff --git a/pkg/vmcp/session/optimizerdec/decorator_test.go b/pkg/vmcp/session/optimizerdec/decorator_test.go new file mode 100644 index 0000000000..4824cd3051 --- /dev/null +++ b/pkg/vmcp/session/optimizerdec/decorator_test.go @@ -0,0 +1,199 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizerdec_test + +import ( + "context" + "errors" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + "github.com/stacklok/toolhive/pkg/vmcp/session/optimizerdec" + sessionmocks "github.com/stacklok/toolhive/pkg/vmcp/session/types/mocks" +) + +// stubOptimizer implements optimizer.Optimizer for tests. +type stubOptimizer struct { + findOutput *optimizer.FindToolOutput + findErr error + callOutput *mcp.CallToolResult + callErr error +} + +func (s *stubOptimizer) FindTool(_ context.Context, _ optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { + return s.findOutput, s.findErr +} + +func (s *stubOptimizer) CallTool(_ context.Context, _ optimizer.CallToolInput) (*mcp.CallToolResult, error) { + return s.callOutput, s.callErr +} + +func TestOptimizerDecorator_Tools(t *testing.T) { + t.Parallel() + + t.Run("returns only find_tool and call_tool", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return([]vmcp.Tool{{Name: "backend_search"}}).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + got := dec.Tools() + require.Len(t, got, 2) + assert.Equal(t, "find_tool", got[0].Name) + assert.Equal(t, "call_tool", got[1].Name) + // Both tools must have non-empty input schemas. + assert.NotEmpty(t, got[0].InputSchema) + assert.NotEmpty(t, got[1].InputSchema) + }) +} + +func TestOptimizerDecorator_CallTool_FindTool(t *testing.T) { + t.Parallel() + + t.Run("find_tool calls optimizer and returns JSON result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + findOutput := &optimizer.FindToolOutput{ + Tools: []mcp.Tool{{Name: "search"}}, + } + opt := &stubOptimizer{findOutput: findOutput} + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{"tool_description": "web search"} + result, err := dec.CallTool(context.Background(), nil, "find_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + // Result should be non-error and contain the marshaled output. + assert.False(t, result.IsError) + // The structured content should be present or content should have JSON text. + require.NotEmpty(t, result.Content) + }) + + t.Run("find_tool propagates optimizer error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{findErr: errors.New("index unavailable")} + dec := optimizerdec.NewDecorator(base, opt) + + result, err := dec.CallTool(context.Background(), nil, "find_tool", map[string]any{"tool_description": "x"}, nil) + + require.NoError(t, err) // errors are surfaced as IsError results per MCP convention + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("find_tool returns error result when optimizer returns nil output", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{findOutput: nil, findErr: nil} + dec := optimizerdec.NewDecorator(base, opt) + + result, err := dec.CallTool(context.Background(), nil, "find_tool", map[string]any{"tool_description": "x"}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) +} + +func TestOptimizerDecorator_CallTool_CallTool(t *testing.T) { + t.Parallel() + + t.Run("call_tool routes through optimizer and converts result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + // The underlying session must NOT be called — call_tool routes through optimizer.CallTool. + + opt := &stubOptimizer{ + callOutput: mcp.NewToolResultText("fetched content"), + } + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{ + "tool_name": "backend_fetch", + "parameters": map[string]any{"url": "https://example.com"}, + } + result, err := dec.CallTool(context.Background(), &auth.Identity{}, "call_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.False(t, result.IsError) + require.Len(t, result.Content, 1) + assert.Equal(t, "fetched content", result.Content[0].Text) + }) + + t.Run("call_tool propagates optimizer error as tool error result", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + opt := &stubOptimizer{callErr: errors.New("backend unreachable")} + dec := optimizerdec.NewDecorator(base, opt) + + args := map[string]any{"tool_name": "backend_fetch"} + result, err := dec.CallTool(context.Background(), nil, "call_tool", args, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("call_tool returns error result when tool_name missing", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + result, err := dec.CallTool(context.Background(), nil, "call_tool", map[string]any{}, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.True(t, result.IsError) + }) + + t.Run("unknown tool returns error", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + base := sessionmocks.NewMockMultiSession(ctrl) + base.EXPECT().Tools().Return(nil).AnyTimes() + + dec := optimizerdec.NewDecorator(base, &stubOptimizer{}) + + _, err := dec.CallTool(context.Background(), nil, "nonexistent_tool", nil, nil) + + require.Error(t, err) + }) +} diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go index 485cb660e2..5bd875cdd9 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_composite_test.go @@ -20,10 +20,15 @@ import ( "github.com/stacklok/toolhive/test/e2e/images" ) -// This test exercises composite tool execution through the optimizer's call_tool. -// Without the fix in injectOptimizerCapabilities, composite tools are registered -// with backend routing handlers (ToSDKTools) instead of workflow execution handlers -// (ToCompositeToolSDKTools), causing call_tool to fail with ErrToolNotFound. +// This test exercises composite tool discovery and execution through the +// optimizer's find_tool / call_tool interface. +// +// Composite tools are registered as session decorators (compositetools.Decorator) +// before the optimizer decorator is applied, so the optimizer indexes them +// alongside backend tools. Workflow steps reference backend tools via the +// "{workloadID}.{originalCapabilityName}" dot convention, which the session +// router resolves to the correct conflict-resolved routing table entry +// regardless of which conflict resolution strategy is in use. // // A lightweight fake embedding server replaces the heavyweight TEI image to keep // test setup fast while satisfying the optimizer's embedding service requirement. @@ -65,6 +70,14 @@ var _ = Describe("VirtualMCPServer Optimizer Composite Tools", Ordered, func() { "url": "{{.params.url}}", } + // Workflow steps use the "{workloadID}.{originalCapabilityName}" dot + // convention so that the session router can resolve them regardless of + // conflict resolution strategy. backendFetchToolName ("fetch") is the + // name the gofetch backend exposes; the aggregation override renames it + // to vmcpFetchToolName for clients, but the step references the + // original backend capability name via the dot convention. + fetchStepTool := backendName + "." + backendFetchToolName // "backend-opt-composite.fetch" + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ ObjectMeta: metav1.ObjectMeta{ Name: vmcpServerName, @@ -103,13 +116,13 @@ var _ = Describe("VirtualMCPServer Optimizer Composite Tools", Ordered, func() { { ID: "first_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, Arguments: thvjson.NewMap(stepArgs), }, { ID: "second_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, DependsOn: []string{"first_fetch"}, Arguments: thvjson.NewMap(stepArgs), }, diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go index 2b35e51859..26469b1e7e 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_optimizer_test.go @@ -69,6 +69,13 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { "url": "{{.params.url}}", } + // Workflow steps use the "{workloadID}.{originalCapabilityName}" dot + // convention so the session router resolves them regardless of conflict + // resolution strategy. backendFetchToolName ("fetch") is the original + // backend capability; the aggregation override renames it to + // vmcpFetchToolName for clients, but steps must reference the original. + fetchStepTool := backendName + "." + backendFetchToolName // "backend-optimizer-fetch.fetch" + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ ObjectMeta: metav1.ObjectMeta{ Name: vmcpServerName, @@ -110,13 +117,13 @@ var _ = Describe("VirtualMCPServer Optimizer Mode", Ordered, func() { { ID: "first_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, Arguments: thvjson.NewMap(stepArgs), }, { ID: "second_fetch", Type: "tool", - Tool: vmcpFetchToolName, + Tool: fetchStepTool, DependsOn: []string{"first_fetch"}, Arguments: thvjson.NewMap(stepArgs), },