From a93efacd8b3ef93883f7143f8d98859c68c3bea0 Mon Sep 17 00:00:00 2001 From: Jeremy Drouillard Date: Tue, 13 Jan 2026 09:41:53 -0800 Subject: [PATCH 01/16] Update vmcp/README --- cmd/vmcp/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cmd/vmcp/README.md b/cmd/vmcp/README.md index 30ac862ca2..10a60bef2a 100644 --- a/cmd/vmcp/README.md +++ b/cmd/vmcp/README.md @@ -6,7 +6,7 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC ## Features -### Implemented (Phase 1) +### Implemented - ✅ **Group-Based Backend Management**: Automatic workload discovery from ToolHive groups - ✅ **Tool Aggregation**: Combines tools from multiple MCP servers with conflict resolution (prefix, priority, manual) - ✅ **Resource & Prompt Aggregation**: Unified access to resources and prompts from all backends @@ -15,12 +15,14 @@ The Virtual MCP Server (vmcp) is a standalone binary that aggregates multiple MC - ✅ **Health Endpoints**: `/health` and `/ping` for service monitoring - ✅ **Configuration Validation**: `vmcp validate` command for config verification - ✅ **Observability**: OpenTelemetry metrics and traces for backend operations and workflow executions +- ✅ **Composite Tools**: Multi-step workflows with elicitation support ### In Progress - 🚧 **Incoming Authentication** (Issue #165): OIDC, local, anonymous authentication - 🚧 **Outgoing Authentication** (Issue #160): RFC 8693 token exchange for backend API access - 🚧 **Token Caching**: Memory and Redis cache providers - 🚧 **Health Monitoring** (Issue #166): Circuit breakers, backend health checks +- 🚧 **Optimizer** Support the MCP optimizer in vMCP for context optimization on large toolsets. ### Future (Phase 2+) - 📋 **Authorization**: Cedar policy-based access control From 16b950336a82724f72d2d6adaf40bb770227f835 Mon Sep 17 00:00:00 2001 From: Nigel Brown Date: Thu, 15 Jan 2026 15:03:06 +0000 Subject: [PATCH 02/16] feat: Add optimizer package with semantic tool discovery and ingestion (#3253) * feat: Add optimizer package with semantic tool discovery and ingestion This PR introduces the optimizer package, a Go port of the mcp-optimizer Python service that provides semantic tool discovery and ingestion for MCP servers. - **Semantic tool search** using vector embeddings (384-dim) - **Token counting** for LLM cost estimation - **Full-text search** via SQLite FTS5 - **Multiple embedding backends**: Ollama, vLLM, or placeholder (testing) - **Production-ready database** with sqlite-vec for vector similarity search --- .golangci.yml | 2 + Taskfile.yml | 11 +- cmd/thv-operator/Taskfile.yml | 2 +- cmd/vmcp/app/commands.go | 36 ++ ...olhive.stacklok.dev_virtualmcpservers.yaml | 78 +++ ...olhive.stacklok.dev_virtualmcpservers.yaml | 78 +++ docs/operator/crd-api.md | 27 + examples/vmcp-config-optimizer.yaml | 113 ++++ go.mod | 9 +- go.sum | 38 +- pkg/optimizer/INTEGRATION.md | 131 ++++ pkg/optimizer/README.md | 337 ++++++++++ pkg/optimizer/db/backend_server.go | 234 +++++++ pkg/optimizer/db/backend_server_test.go | 424 +++++++++++++ pkg/optimizer/db/backend_tool.go | 310 ++++++++++ pkg/optimizer/db/backend_tool_test.go | 579 ++++++++++++++++++ pkg/optimizer/db/db.go | 182 ++++++ pkg/optimizer/db/fts.go | 341 +++++++++++ pkg/optimizer/db/hybrid.go | 167 +++++ pkg/optimizer/db/schema_fts.sql | 120 ++++ pkg/optimizer/db/sqlite_fts.go | 8 + pkg/optimizer/doc.go | 83 +++ pkg/optimizer/embeddings/cache.go | 101 +++ pkg/optimizer/embeddings/cache_test.go | 169 +++++ pkg/optimizer/embeddings/manager.go | 281 +++++++++ pkg/optimizer/embeddings/ollama.go | 128 ++++ pkg/optimizer/embeddings/ollama_test.go | 106 ++++ pkg/optimizer/embeddings/openai_compatible.go | 149 +++++ .../embeddings/openai_compatible_test.go | 235 +++++++ pkg/optimizer/ingestion/errors.go | 21 + pkg/optimizer/ingestion/service.go | 215 +++++++ pkg/optimizer/ingestion/service_test.go | 148 +++++ pkg/optimizer/models/errors.go | 16 + pkg/optimizer/models/models.go | 173 ++++++ pkg/optimizer/models/models_test.go | 270 ++++++++ pkg/optimizer/models/transport.go | 111 ++++ pkg/optimizer/models/transport_test.go | 273 +++++++++ pkg/optimizer/tokens/counter.go | 65 ++ pkg/optimizer/tokens/counter_test.go | 143 +++++ pkg/vmcp/config/config.go | 81 +++ pkg/vmcp/config/zz_generated.deepcopy.go | 20 + pkg/vmcp/optimizer/optimizer.go | 364 +++++++++++ .../optimizer/optimizer_integration_test.go | 167 +++++ pkg/vmcp/optimizer/optimizer_unit_test.go | 260 ++++++++ pkg/vmcp/router/default_router.go | 15 + pkg/vmcp/server/mocks/mock_watcher.go | 83 +++ pkg/vmcp/server/server.go | 144 ++++- scripts/README.md | 96 +++ .../inspect-chromem-raw.go | 106 ++++ scripts/inspect-chromem/inspect-chromem.go | 123 ++++ scripts/inspect-optimizer-db.sh | 63 ++ scripts/query-optimizer-db.sh | 46 ++ scripts/test-optimizer-with-sqlite-vec.sh | 117 ++++ .../view-chromem-tool/view-chromem-tool.go | 153 +++++ 54 files changed, 7732 insertions(+), 20 deletions(-) create mode 100644 examples/vmcp-config-optimizer.yaml create mode 100644 pkg/optimizer/INTEGRATION.md create mode 100644 pkg/optimizer/README.md create mode 100644 pkg/optimizer/db/backend_server.go create mode 100644 pkg/optimizer/db/backend_server_test.go create mode 100644 pkg/optimizer/db/backend_tool.go create mode 100644 pkg/optimizer/db/backend_tool_test.go create mode 100644 pkg/optimizer/db/db.go create mode 100644 pkg/optimizer/db/fts.go create mode 100644 pkg/optimizer/db/hybrid.go create mode 100644 pkg/optimizer/db/schema_fts.sql create mode 100644 pkg/optimizer/db/sqlite_fts.go create mode 100644 pkg/optimizer/doc.go create mode 100644 pkg/optimizer/embeddings/cache.go create mode 100644 pkg/optimizer/embeddings/cache_test.go create mode 100644 pkg/optimizer/embeddings/manager.go create mode 100644 pkg/optimizer/embeddings/ollama.go create mode 100644 pkg/optimizer/embeddings/ollama_test.go create mode 100644 pkg/optimizer/embeddings/openai_compatible.go create mode 100644 pkg/optimizer/embeddings/openai_compatible_test.go create mode 100644 pkg/optimizer/ingestion/errors.go create mode 100644 pkg/optimizer/ingestion/service.go create mode 100644 pkg/optimizer/ingestion/service_test.go create mode 100644 pkg/optimizer/models/errors.go create mode 100644 pkg/optimizer/models/models.go create mode 100644 pkg/optimizer/models/models_test.go create mode 100644 pkg/optimizer/models/transport.go create mode 100644 pkg/optimizer/models/transport_test.go create mode 100644 pkg/optimizer/tokens/counter.go create mode 100644 pkg/optimizer/tokens/counter_test.go create mode 100644 pkg/vmcp/optimizer/optimizer.go create mode 100644 pkg/vmcp/optimizer/optimizer_integration_test.go create mode 100644 pkg/vmcp/optimizer/optimizer_unit_test.go create mode 100644 scripts/README.md create mode 100644 scripts/inspect-chromem-raw/inspect-chromem-raw.go create mode 100644 scripts/inspect-chromem/inspect-chromem.go create mode 100755 scripts/inspect-optimizer-db.sh create mode 100755 scripts/query-optimizer-db.sh create mode 100755 scripts/test-optimizer-with-sqlite-vec.sh create mode 100644 scripts/view-chromem-tool/view-chromem-tool.go diff --git a/.golangci.yml b/.golangci.yml index 62c3611473..ff2b3d54e9 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -139,6 +139,7 @@ linters: - third_party$ - builtin$ - examples$ + - scripts$ formatters: enable: - gci @@ -155,3 +156,4 @@ formatters: - third_party$ - builtin$ - examples$ + - scripts$ diff --git a/Taskfile.yml b/Taskfile.yml index 0554492f72..3ec329b706 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -153,6 +153,11 @@ tasks: - task: test-e2e-windows platforms: [windows] + test-optimizer: + desc: Run optimizer integration tests with sqlite-vec + cmds: + - ./scripts/test-optimizer-with-sqlite-vec.sh + test-all: desc: Run all tests (unit and e2e) deps: [test, test-e2e] @@ -200,12 +205,12 @@ tasks: cmds: - cmd: mkdir -p bin platforms: [linux, darwin] - - cmd: go build -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -o bin/vmcp ./cmd/vmcp + - cmd: go build -tags="fts5" -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -o bin/vmcp ./cmd/vmcp platforms: [linux, darwin] - cmd: cmd.exe /c mkdir bin platforms: [windows] ignore_error: true - - cmd: go build -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -o bin/vmcp.exe ./cmd/vmcp + - cmd: go build -tags="fts5" -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -o bin/vmcp.exe ./cmd/vmcp platforms: [windows] install-vmcp: @@ -217,7 +222,7 @@ tasks: sh: git rev-parse --short HEAD || echo "unknown" BUILD_DATE: '{{dateInZone "2006-01-02T15:04:05Z" (now) "UTC"}}' cmds: - - go install -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -v ./cmd/vmcp + - go install -tags="fts5" -ldflags "-s -w -X github.com/stacklok/toolhive/pkg/versions.Version={{.VERSION}} -X github.com/stacklok/toolhive/pkg/versions.Commit={{.COMMIT}} -X github.com/stacklok/toolhive/pkg/versions.BuildDate={{.BUILD_DATE}}" -v ./cmd/vmcp all: desc: Run linting, tests, and build diff --git a/cmd/thv-operator/Taskfile.yml b/cmd/thv-operator/Taskfile.yml index f67050e875..0bee121944 100644 --- a/cmd/thv-operator/Taskfile.yml +++ b/cmd/thv-operator/Taskfile.yml @@ -200,7 +200,7 @@ tasks: ignore_error: true # Windows has no mkdir -p, so just ignore error if it exists - go install sigs.k8s.io/controller-tools/cmd/controller-gen@v0.17.3 - $(go env GOPATH)/bin/controller-gen rbac:roleName=toolhive-operator-manager-role paths="{{.CONTROLLER_GEN_PATHS}}" output:rbac:artifacts:config={{.PROJECT_ROOT}}/deploy/charts/operator/templates/clusterrole - - $(go env GOPATH)/bin/controller-gen crd webhook paths="{{.CONTROLLER_GEN_PATHS}}" output:crd:artifacts:config={{.PROJECT_ROOT}}/deploy/charts/operator-crds/files/crds + - $(go env GOPATH)/bin/controller-gen crd:allowDangerousTypes=true webhook paths="{{.CONTROLLER_GEN_PATHS}}" output:crd:artifacts:config={{.PROJECT_ROOT}}/deploy/charts/operator-crds/files/crds # Wrap CRDs with Helm templates for conditional installation - go run {{.PROJECT_ROOT}}/deploy/charts/operator-crds/crd-helm-wrapper/main.go -source {{.PROJECT_ROOT}}/deploy/charts/operator-crds/files/crds -target {{.PROJECT_ROOT}}/deploy/charts/operator-crds/templates # - "{{.PROJECT_ROOT}}/deploy/charts/operator-crds/scripts/wrap-crds.sh" diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index ca6060bcab..2d5c8f28e3 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -416,6 +416,42 @@ func runServe(cmd *cobra.Command, _ []string) error { Watcher: backendWatcher, } + // Configure optimizer if enabled in YAML config + if cfg.Optimizer != nil && cfg.Optimizer.Enabled { + logger.Info("🔬 Optimizer enabled via configuration (chromem-go)") + hybridRatio := 0.7 // Default + if cfg.Optimizer.HybridSearchRatio != nil { + hybridRatio = *cfg.Optimizer.HybridSearchRatio + } + serverCfg.OptimizerConfig = &vmcpserver.OptimizerConfig{ + Enabled: cfg.Optimizer.Enabled, + PersistPath: cfg.Optimizer.PersistPath, + FTSDBPath: cfg.Optimizer.FTSDBPath, + HybridSearchRatio: hybridRatio, + EmbeddingBackend: cfg.Optimizer.EmbeddingBackend, + EmbeddingURL: cfg.Optimizer.EmbeddingURL, + EmbeddingModel: cfg.Optimizer.EmbeddingModel, + EmbeddingDimension: cfg.Optimizer.EmbeddingDimension, + } + persistInfo := "in-memory" + if cfg.Optimizer.PersistPath != "" { + persistInfo = cfg.Optimizer.PersistPath + } + // FTS5 is always enabled with configurable semantic/BM25 ratio + ratio := 0.7 // Default + if cfg.Optimizer.HybridSearchRatio != nil { + ratio = *cfg.Optimizer.HybridSearchRatio + } + searchMode := fmt.Sprintf("hybrid (%.0f%% semantic, %.0f%% BM25)", + ratio*100, + (1-ratio)*100) + logger.Infof("Optimizer configured: backend=%s, dimension=%d, persistence=%s, search=%s", + cfg.Optimizer.EmbeddingBackend, + cfg.Optimizer.EmbeddingDimension, + persistInfo, + searchMode) + } + // Convert composite tool configurations to workflow definitions workflowDefs, err := vmcpserver.ConvertConfigToWorkflowDefinitions(cfg.CompositeTools) if err != nil { diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index 0806bb46b4..a20b4b1625 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -628,6 +628,84 @@ spec: type: object type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes optim.find_tool and optim.call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. + properties: + embeddingBackend: + description: |- + EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + - "ollama": Uses local Ollama HTTP API for embeddings + - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + enum: + - ollama + - openai-compatible + - placeholder + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides embeddings (K8s only). + This is an alternative to EmbeddingURL for in-cluster deployments. + When set, vMCP will resolve the service DNS name for the embedding API. + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim.find_tool and optim.call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. + type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + Value range: 0.0 (all BM25) to 1.0 (all semantic). + Default: 0.7 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 1 + minimum: 0 + type: number + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. + type: string + type: object outgoingAuth: description: |- OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 250c99f8d6..466a0906ce 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -631,6 +631,84 @@ spec: type: object type: object type: object + optimizer: + description: |- + Optimizer configures the MCP optimizer for context optimization on large toolsets. + When enabled, vMCP exposes optim.find_tool and optim.call_tool operations to clients + instead of all backend tools directly. This reduces token usage by allowing + LLMs to discover relevant tools on demand rather than receiving all tool definitions. + properties: + embeddingBackend: + description: |- + EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + - "ollama": Uses local Ollama HTTP API for embeddings + - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + enum: + - ollama + - openai-compatible + - placeholder + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingService: + description: |- + EmbeddingService is the name of a Kubernetes Service that provides embeddings (K8s only). + This is an alternative to EmbeddingURL for in-cluster deployments. + When set, vMCP will resolve the service DNS name for the embedding API. + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim.find_tool and optim.call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. + type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + Value range: 0.0 (all BM25) to 1.0 (all semantic). + Default: 0.7 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 1 + minimum: 0 + type: number + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. + type: string + type: object outgoingAuth: description: |- OutgoingAuth configures how the virtual MCP server authenticates to backends. diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index cbd532f4c7..4250170268 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -244,6 +244,7 @@ _Appears in:_ | `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | | `telemetry` _[pkg.telemetry.Config](#pkgtelemetryconfig)_ | Telemetry configures OpenTelemetry-based observability for the Virtual MCP server
including distributed tracing, OTLP metrics export, and Prometheus metrics endpoint. | | | | `audit` _[pkg.audit.Config](#pkgauditconfig)_ | Audit configures audit logging for the Virtual MCP server.
When present, audit logs include MCP protocol operations.
See audit.Config for available configuration options. | | | +| `optimizer` _[vmcp.config.OptimizerConfig](#vmcpconfigoptimizerconfig)_ | Optimizer configures the MCP optimizer for context optimization on large toolsets.
When enabled, vMCP exposes optim.find_tool and optim.call_tool operations to clients
instead of all backend tools directly. This reduces token usage by allowing
LLMs to discover relevant tools on demand rather than receiving all tool definitions. | | | #### vmcp.config.ConflictResolutionConfig @@ -371,6 +372,32 @@ _Appears in:_ | `failureHandling` _[vmcp.config.FailureHandlingConfig](#vmcpconfigfailurehandlingconfig)_ | FailureHandling configures failure handling behavior. | | | +#### vmcp.config.OptimizerConfig + + + +OptimizerConfig configures the MCP optimizer for semantic tool discovery. +The optimizer reduces token usage by allowing LLMs to discover relevant tools +on demand rather than receiving all tool definitions upfront. + + + +_Appears in:_ +- [vmcp.config.Config](#vmcpconfigconfig) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enabled` _boolean_ | Enabled determines whether the optimizer is active.
When true, vMCP exposes optim.find_tool and optim.call_tool instead of all backend tools. | | | +| `embeddingBackend` _string_ | EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder".
- "ollama": Uses local Ollama HTTP API for embeddings
- "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.)
- "placeholder": Uses deterministic hash-based embeddings (for testing/development) | | Enum: [ollama openai-compatible placeholder]
| +| `embeddingURL` _string_ | EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API).
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "http://localhost:11434"
- vLLM: "http://vllm-service:8000/v1"
- OpenAI: "https://api.openai.com/v1" | | | +| `embeddingModel` _string_ | EmbeddingModel is the model name to use for embeddings.
Required when EmbeddingBackend is "ollama" or "openai-compatible".
Examples:
- Ollama: "nomic-embed-text", "all-minilm"
- vLLM: "BAAI/bge-small-en-v1.5"
- OpenAI: "text-embedding-3-small" | | | +| `embeddingDimension` _integer_ | EmbeddingDimension is the dimension of the embedding vectors.
Common values:
- 384: all-MiniLM-L6-v2, nomic-embed-text
- 768: BAAI/bge-small-en-v1.5
- 1536: OpenAI text-embedding-3-small | | Minimum: 1
| +| `persistPath` _string_ | PersistPath is the optional filesystem path for persisting the chromem-go database.
If empty, the database will be in-memory only (ephemeral).
When set, tool metadata and embeddings are persisted to disk for faster restarts. | | | +| `ftsDBPath` _string_ | FTSDBPath is the path to the SQLite FTS5 database for BM25 text search.
If empty, defaults to ":memory:" for in-memory FTS5, or "\{PersistPath\}/fts.db" if PersistPath is set.
Hybrid search (semantic + BM25) is always enabled. | | | +| `hybridSearchRatio` _float_ | HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search.
Value range: 0.0 (all BM25) to 1.0 (all semantic).
Default: 0.7 (70% semantic, 30% BM25)
Only used when FTSDBPath is set. | | Maximum: 1
Minimum: 0
| +| `embeddingService` _string_ | EmbeddingService is the name of a Kubernetes Service that provides embeddings (K8s only).
This is an alternative to EmbeddingURL for in-cluster deployments.
When set, vMCP will resolve the service DNS name for the embedding API. | | | + + #### vmcp.config.OutgoingAuthConfig diff --git a/examples/vmcp-config-optimizer.yaml b/examples/vmcp-config-optimizer.yaml new file mode 100644 index 0000000000..5b20b074d9 --- /dev/null +++ b/examples/vmcp-config-optimizer.yaml @@ -0,0 +1,113 @@ +# vMCP Configuration with Optimizer Enabled +# This configuration enables the optimizer for semantic tool discovery + +name: "vmcp-debug" + +# Reference to ToolHive group containing MCP servers +groupRef: "default" + +# Client authentication (anonymous for local development) +incomingAuth: + type: anonymous + +# Backend authentication (unauthenticated for local development) +outgoingAuth: + source: inline + default: + type: unauthenticated + +# Tool aggregation settings +aggregation: + conflictResolution: prefix + conflictResolutionConfig: + prefixFormat: "{workload}_" + +# Operational settings +operational: + timeouts: + default: 30s + failureHandling: + healthCheckInterval: 30s + unhealthyThreshold: 3 + partialFailureMode: fail + +# ============================================================================= +# OPTIMIZER CONFIGURATION +# ============================================================================= +# When enabled, vMCP exposes optim.find_tool and optim.call_tool instead of +# all backend tools directly. This reduces token usage by allowing LLMs to +# discover relevant tools on demand via semantic search. +# +# The optimizer ingests tools from all backends in the group, generates +# embeddings, and provides semantic search capabilities. + +optimizer: + # Enable the optimizer + enabled: true + + # Embedding backend: "ollama", "openai-compatible", or "placeholder" + # - "ollama": Uses local Ollama HTTP API for embeddings + # - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + # - "placeholder": Uses deterministic hash-based embeddings (for testing) + embeddingBackend: placeholder + + # Embedding dimension (common values: 384, 768, 1536) + # 384 is standard for all-MiniLM-L6-v2 and nomic-embed-text + embeddingDimension: 384 + + # Optional: Path for persisting the chromem-go database + # If omitted, the database will be in-memory only (ephemeral) + persistPath: /tmp/vmcp-optimizer-debug.db + + # Optional: Path for the SQLite FTS5 database (for hybrid search) + # Default: ":memory:" (in-memory) or "{persistPath}/fts.db" if persistPath is set + # Hybrid search (semantic + BM25) is ALWAYS enabled + ftsDBPath: /tmp/vmcp-optimizer-fts.db # Uncomment to customize location + + # Optional: Hybrid search ratio (0.0 = all BM25, 1.0 = all semantic) + # Default: 0.7 (70% semantic, 30% BM25) + # hybridSearchRatio: 0.7 + + # ============================================================================= + # PRODUCTION CONFIGURATIONS (Commented Examples) + # ============================================================================= + + # Option 1: Local Ollama (good for development/testing) + # embeddingBackend: ollama + # embeddingURL: http://localhost:11434 + # embeddingModel: nomic-embed-text + # embeddingDimension: 384 + + # Option 2: vLLM (recommended for production with GPU acceleration) + # embeddingBackend: openai-compatible + # embeddingURL: http://vllm-service:8000/v1 + # embeddingModel: BAAI/bge-small-en-v1.5 + # embeddingDimension: 768 + + # Option 3: OpenAI API (cloud-based) + # embeddingBackend: openai-compatible + # embeddingURL: https://api.openai.com/v1 + # embeddingModel: text-embedding-3-small + # embeddingDimension: 1536 + # (requires OPENAI_API_KEY environment variable) + + # Option 4: Kubernetes in-cluster service (K8s deployments) + # embeddingService: embedding-service-name + # (vMCP will resolve the service DNS name) + +# ============================================================================= +# USAGE +# ============================================================================= +# 1. Start MCP backends in the group: +# thv run weather --group default +# thv run github --group default +# +# 2. Start vMCP with optimizer: +# thv vmcp serve --config examples/vmcp-config-optimizer.yaml +# +# 3. Connect MCP client to vMCP +# +# 4. Available tools from vMCP: +# - optim.find_tool: Search for tools by semantic query +# - optim.call_tool: Execute a tool by name +# - (backend tools are NOT directly exposed when optimizer is enabled) diff --git a/go.mod b/go.mod index 1cec9beed0..ceebcde0ec 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/onsi/ginkgo/v2 v2.27.5 github.com/onsi/gomega v1.39.0 github.com/ory/fosite v0.49.0 + github.com/philippgille/chromem-go v0.7.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/prometheus/client_golang v1.23.2 github.com/sigstore/protobuf-specs v0.5.0 @@ -59,6 +60,7 @@ require ( k8s.io/api v0.35.0 k8s.io/apimachinery v0.35.0 k8s.io/utils v0.0.0-20260108192941-914a6e750570 + modernc.org/sqlite v1.44.0 sigs.k8s.io/controller-runtime v0.22.4 sigs.k8s.io/yaml v1.6.0 ) @@ -174,6 +176,7 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.1.0 // indirect @@ -188,6 +191,7 @@ require ( github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect github.com/prometheus/procfs v0.19.2 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect @@ -251,6 +255,9 @@ require ( k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect + modernc.org/libc v1.67.4 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect @@ -286,7 +293,7 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 go.opentelemetry.io/otel/trace v1.39.0 golang.org/x/crypto v0.47.0 - golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/sys v0.40.0 k8s.io/client-go v0.35.0 ) diff --git a/go.sum b/go.sum index 536fae90cc..1c62e31397 100644 --- a/go.sum +++ b/go.sum @@ -600,6 +600,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes= github.com/nyaruka/phonenumbers v1.1.6/go.mod h1:yShPJHDSH3aTKzCbXyVxNpbl2kA+F+Ne5Pun/MvFRos= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -638,6 +640,8 @@ github.com/ory/x v0.0.665 h1:61vv0ObCDSX1vOQYbxBeqDiv4YiPmMT91lYxDaaKX08= github.com/ory/x v0.0.665/go.mod h1:7SCTki3N0De3ZpqlxhxU/94ZrOCfNEnXwVtd0xVt+L8= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= +github.com/philippgille/chromem-go v0.7.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -659,6 +663,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -907,8 +913,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93 h1:Fee8ke0jLfLhU4ywDLs7IYmhJ8MrSP0iZE3p39EKKSc= golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93/go.mod h1:HgAgrKXB9WF2wFZJBGBnRVkmsC8n+v2ja/8VR0H3QkY= golang.org/x/exp/jsonrpc2 v0.0.0-20260112195511-716be5621a96 h1:cN9X2vSBmT3Ruw2UlbJNLJh0iBqTmtSB0dRfh5aumiY= @@ -1084,6 +1090,34 @@ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY= k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg= +modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc= +modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A= sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= diff --git a/pkg/optimizer/INTEGRATION.md b/pkg/optimizer/INTEGRATION.md new file mode 100644 index 0000000000..4d2db78b59 --- /dev/null +++ b/pkg/optimizer/INTEGRATION.md @@ -0,0 +1,131 @@ +# Integrating Optimizer with vMCP + +## Overview + +The optimizer package ingests MCP server and tool metadata into a searchable database with semantic embeddings. This enables intelligent tool discovery and token optimization for LLM consumption. + +## Integration Approach + +**Event-Driven Ingestion**: The optimizer integrates directly with vMCP's startup process. When vMCP starts and loads its configured servers, it calls the optimizer to ingest each server's metadata and tools. + +❌ **NOT** a separate polling service discovering backends +✅ **IS** called directly by vMCP during server initialization + +## How It Is Integrated + +The optimizer is already integrated into vMCP and works automatically when enabled via configuration. Here's how the integration works: + +### Initialization + +When vMCP starts with optimizer enabled in the configuration, it: + +1. Initializes the optimizer database (chromem-go + SQLite FTS5) +2. Configures the embedding backend (placeholder, Ollama, or vLLM) +3. Sets up the ingestion service + +### Automatic Ingestion + +The optimizer integrates with vMCP's `OnRegisterSession` hook, which is called whenever: + +- vMCP starts and loads configured MCP servers +- A new MCP server is dynamically added +- A session reconnects or refreshes + +When this hook is triggered, the optimizer: + +1. Retrieves the server's metadata and tools via MCP protocol +2. Generates embeddings for searchable content +3. Stores the data in both the vector database (chromem-go) and FTS5 database +4. Makes the tools immediately available for semantic search + +### Exposed Tools + +When the optimizer is enabled, vMCP automatically exposes these tools to LLM clients: + +- `optim.find_tool`: Semantic search for tools across all registered servers +- `optim.call_tool`: Dynamic tool invocation after discovery + +### Implementation Location + +The integration code is located in: +- `cmd/vmcp/optimizer.go`: Optimizer initialization and configuration +- `pkg/vmcp/optimizer/optimizer.go`: Session registration hook implementation +- `pkg/optimizer/ingestion/service.go`: Core ingestion service + +## Configuration + +Add optimizer configuration to vMCP's config: + +```yaml +# vMCP config +optimizer: + enabled: true + db_path: /data/optimizer.db + embedding: + backend: vllm # or "ollama" for local dev, "placeholder" for testing + url: http://vllm-service:8000 + model: sentence-transformers/all-MiniLM-L6-v2 + dimension: 384 +``` + +## Error Handling + +**Important**: Optimizer failures should NOT break vMCP functionality: + +- ✅ Log warnings if optimizer fails +- ✅ Continue server startup even if ingestion fails +- ✅ Run ingestion in goroutines to avoid blocking +- ❌ Don't fail server startup if optimizer is unavailable + +## Benefits + +1. **Automatic**: Servers are indexed as they're added to vMCP +2. **Up-to-date**: Database reflects current vMCP state +3. **No polling**: Event-driven, efficient +4. **Semantic search**: Enables intelligent tool discovery +5. **Token optimization**: Tracks token usage for LLM efficiency + +## Testing + +```go +func TestOptimizerIntegration(t *testing.T) { + // Initialize optimizer + optimizerSvc, err := ingestion.NewService(&ingestion.Config{ + DBConfig: &db.Config{Path: "/tmp/test-optimizer.db"}, + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + }) + require.NoError(t, err) + defer optimizerSvc.Close() + + // Simulate vMCP starting a server + ctx := context.Background() + tools := []mcp.Tool{ + {Name: "get_weather", Description: "Get current weather"}, + {Name: "get_forecast", Description: "Get weather forecast"}, + } + + err = optimizerSvc.IngestServer( + ctx, + "weather-001", + "weather-service", + "http://weather.local", + models.TransportSSE, + ptr("Weather information service"), + tools, + ) + require.NoError(t, err) + + // Verify ingestion + server, err := optimizerSvc.GetServer(ctx, "weather-001") + require.NoError(t, err) + assert.Equal(t, "weather-service", server.Name) +} +``` + +## See Also + +- [Optimizer Package README](./README.md) - Package overview and API + diff --git a/pkg/optimizer/README.md b/pkg/optimizer/README.md new file mode 100644 index 0000000000..2984f2697a --- /dev/null +++ b/pkg/optimizer/README.md @@ -0,0 +1,337 @@ +# Optimizer Package + +The optimizer package provides semantic tool discovery and ingestion for MCP servers in ToolHive's vMCP. It enables intelligent, context-aware tool selection to reduce token usage and improve LLM performance. + +## Features + +- **Pure Go**: No CGO dependencies - uses [chromem-go](https://github.com/philippgille/chromem-go) for vector search and `modernc.org/sqlite` for FTS5 +- **Hybrid Search**: Combines semantic search (chromem-go) with BM25 full-text search (SQLite FTS5) +- **In-Memory by Default**: Fast ephemeral database with optional persistence +- **Pluggable Embeddings**: Supports vLLM, Ollama, and placeholder backends +- **Event-Driven**: Integrates with vMCP's `OnRegisterSession` hook for automatic ingestion +- **Semantic + Keyword Search**: Configurable ratio between semantic and BM25 search +- **Token Counting**: Tracks token usage for LLM consumption metrics + +## Architecture + +``` +pkg/optimizer/ +├── models/ # Domain models (Server, Tool, etc.) +├── db/ # Hybrid database layer (chromem-go + SQLite FTS5) +│ ├── db.go # Database coordinator +│ ├── fts.go # SQLite FTS5 for BM25 search (pure Go) +│ ├── hybrid.go # Hybrid search combining semantic + BM25 +│ ├── backend_server.go # Server operations +│ └── backend_tool.go # Tool operations +├── embeddings/ # Embedding backends (vLLM, Ollama, placeholder) +├── ingestion/ # Event-driven ingestion service +└── tokens/ # Token counting for LLM metrics +``` + +## Embedding Backends + +The optimizer supports multiple embedding backends: + +| Backend | Use Case | Performance | Setup | +|---------|----------|-------------|-------| +| **vLLM** | **Production/Kubernetes (recommended)** | Excellent (GPU) | Deploy vLLM service | +| Ollama | Local development, CPU-only | Good | `ollama serve` | +| Placeholder | Testing, CI/CD | Fast (hash-based) | Zero setup | + +**For production Kubernetes deployments, vLLM is recommended** due to its high-throughput performance, GPU efficiency (PagedAttention), and scalability for multi-user environments. + +## Hybrid Search + +The optimizer **always uses hybrid search** combining: + +1. **Semantic Search** (chromem-go): Understands meaning and context via embeddings +2. **BM25 Full-Text Search** (SQLite FTS5): Keyword matching with Porter stemming + +This dual approach ensures the best of both worlds: semantic understanding for intent-based queries and keyword precision for technical terms and acronyms. + +### Configuration + +```yaml +optimizer: + enabled: true + embeddingBackend: placeholder + embeddingDimension: 384 + # persistPath: /data/optimizer # Optional: for persistence + # ftsDBPath: /data/optimizer-fts.db # Optional: defaults to :memory: or {persistPath}/fts.db + hybridSearchRatio: 0.7 # 70% semantic, 30% BM25 (default) +``` + +| Ratio | Semantic | BM25 | Best For | +|-------|----------|------|----------| +| 1.0 | 100% | 0% | Pure semantic (intent-heavy queries) | +| 0.7 | 70% | 30% | **Default**: Balanced hybrid | +| 0.5 | 50% | 50% | Equal weight | +| 0.0 | 0% | 100% | Pure keyword (exact term matching) | + +### How It Works + +1. **Parallel Execution**: Semantic and BM25 searches run concurrently +2. **Result Merging**: Combines results and removes duplicates +3. **Ranking**: Sorts by similarity/relevance score +4. **Limit Enforcement**: Returns top N results + +### Example Queries + +| Query | Semantic Match | BM25 Match | Winner | +|-------|----------------|------------|--------| +| "What's the weather?" | ✅ `get_current_weather` | ✅ `weather_forecast` | Both (deduped) | +| "SQL database query" | ❌ (no embeddings) | ✅ `execute_sql` | BM25 | +| "Make it rain outside" | ✅ `weather_control` | ❌ (no keyword) | Semantic | + +## Quick Start + +### vMCP Integration (Recommended) + +The optimizer is designed to work as part of vMCP, not standalone: + +```yaml +# examples/vmcp-config-optimizer.yaml +optimizer: + enabled: true + embeddingBackend: placeholder # or "ollama", "openai-compatible" + embeddingDimension: 384 + # persistPath: /data/optimizer # Optional: for chromem-go persistence + # ftsDBPath: /data/fts.db # Optional: auto-defaults to :memory: or {persistPath}/fts.db + # hybridSearchRatio: 0.7 # Optional: 70% semantic, 30% BM25 (default) +``` + +Start vMCP with optimizer: + +```bash +thv vmcp serve --config examples/vmcp-config-optimizer.yaml +``` + +When optimizer is enabled, vMCP exposes: +- `optim.find_tool`: Semantic search for tools +- `optim.call_tool`: Dynamic tool invocation + +### Programmatic Usage + +```go +import ( + "context" + + "github.com/stacklok/toolhive/pkg/optimizer/db" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/optimizer/ingestion" +) + +func main() { + ctx := context.Background() + + // Initialize database (in-memory) + database, err := db.NewDB(&db.Config{ + PersistPath: "", // Empty = in-memory only + }) + if err != nil { + panic(err) + } + + // Initialize embedding manager with placeholder (no external dependencies) + embeddingMgr, err := embeddings.NewManager(&embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }) + if err != nil { + panic(err) + } + + // Create ingestion service + svc, err := ingestion.NewService(&ingestion.Config{ + DBConfig: &db.Config{PersistPath: ""}, + EmbeddingConfig: embeddingMgr.Config(), + }) + if err != nil { + panic(err) + } + defer svc.Close() + + // Ingest a server (called by vMCP on session registration) + err = svc.IngestServer(ctx, "server-id", "MyServer", nil, []mcp.Tool{...}) + if err != nil { + panic(err) + } +} +``` + +### Production Deployment with vLLM (Kubernetes) + +```yaml +optimizer: + enabled: true + embeddingBackend: openai-compatible + embeddingURL: http://vllm-service:8000/v1 + embeddingModel: BAAI/bge-small-en-v1.5 + embeddingDimension: 768 + persistPath: /data/optimizer # Persistent storage for faster restarts +``` + +Deploy vLLM alongside vMCP: + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: vllm-embeddings +spec: + template: + spec: + containers: + - name: vllm + image: vllm/vllm-openai:latest + args: + - --model + - BAAI/bge-small-en-v1.5 + - --port + - "8000" + resources: + limits: + nvidia.com/gpu: 1 +``` + +### Local Development with Ollama + +```bash +# Start Ollama +ollama serve + +# Pull an embedding model +ollama pull nomic-embed-text +``` + +Configure vMCP: + +```yaml +optimizer: + enabled: true + embeddingBackend: ollama + embeddingURL: http://localhost:11434 + embeddingModel: nomic-embed-text + embeddingDimension: 384 +``` + +## Configuration + +### Database + +- **Storage**: chromem-go (pure Go, no CGO) +- **Default**: In-memory (ephemeral) +- **Persistence**: Optional via `persistPath` +- **Format**: Binary (gob encoding) + +### Embedding Models + +Common embedding dimensions: +- **384**: all-MiniLM-L6-v2, nomic-embed-text (default) +- **768**: BAAI/bge-small-en-v1.5 +- **1536**: OpenAI text-embedding-3-small + +### Performance + +From chromem-go benchmarks (mid-range 2020 Intel laptop): +- **1,000 tools**: ~0.5ms query time +- **5,000 tools**: ~2.2ms query time +- **25,000 tools**: ~9.9ms query time +- **100,000 tools**: ~39.6ms query time + +Perfect for typical vMCP deployments (hundreds to thousands of tools). + +## Testing + +Run the unit tests: + +```bash +# Test all packages +go test ./pkg/optimizer/... + +# Test with coverage +go test -cover ./pkg/optimizer/... + +# Test specific package +go test ./pkg/optimizer/models +``` + +## Inspecting the Database + +The optimizer uses a hybrid database (chromem-go + SQLite FTS5). Here's how to inspect each: + +### Inspecting SQLite FTS5 (Easiest) + +The FTS5 database is standard SQLite and can be opened with any SQLite tool: + +```bash +# Use sqlite3 CLI +sqlite3 /tmp/vmcp-optimizer-fts.db + +# Count documents +SELECT COUNT(*) FROM backend_servers_fts; +SELECT COUNT(*) FROM backend_tools_fts; + +# View tool names and descriptions +SELECT tool_name, tool_description FROM backend_tools_fts LIMIT 10; + +# Full-text search with BM25 ranking +SELECT tool_name, rank +FROM backend_tool_fts_index +WHERE backend_tool_fts_index MATCH 'github repository' +ORDER BY rank +LIMIT 5; + +# Join servers and tools +SELECT s.name, t.tool_name, t.tool_description +FROM backend_tools_fts t +JOIN backend_servers_fts s ON t.mcpserver_id = s.id +LIMIT 10; +``` + +**VSCode Extension**: Install `alexcvzz.vscode-sqlite` to view `.db` files directly in VSCode. + +### Inspecting chromem-go (Vector Database) + +chromem-go uses `.gob` binary files. Use the provided inspection scripts: + +```bash +# Quick summary (shows collection sizes and first few documents) +go run scripts/inspect-chromem-raw.go /tmp/vmcp-optimizer-debug.db + +# View specific tool with full metadata and embeddings +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents + +# View all documents (warning: lots of output) +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db + +# Search by content +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db "search" +``` + +### chromem-go Schema + +Each document in chromem-go contains: + +```go +Document { + ID: string // "github" or UUID for tools + Content: string // "tool_name. description..." + Embedding: []float32 // 384-dimensional vector + Metadata: map[string]string // {"type": "backend_tool", "server_id": "github", "data": "...JSON..."} +} +``` + +**Collections**: +- `backend_servers`: Server metadata (3 documents in typical setup) +- `backend_tools`: Tool metadata and embeddings (40+ documents) + +## Known Limitations + +1. **Scale**: Optimized for <100,000 tools (more than sufficient for typical vMCP deployments) +2. **Approximate Search**: chromem-go uses exhaustive search (not HNSW), but this is fine for our scale +3. **Persistence Format**: Binary gob format (not human-readable) + +## License + +This package is part of ToolHive and follows the same license. diff --git a/pkg/optimizer/db/backend_server.go b/pkg/optimizer/db/backend_server.go new file mode 100644 index 0000000000..8685d4c47d --- /dev/null +++ b/pkg/optimizer/db/backend_server.go @@ -0,0 +1,234 @@ +// Package db provides chromem-go based database operations for the optimizer. +package db + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// BackendServerOps provides operations for backend servers in chromem-go +type BackendServerOps struct { + db *DB + embeddingFunc chromem.EmbeddingFunc +} + +// NewBackendServerOps creates a new BackendServerOps instance +func NewBackendServerOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendServerOps { + return &BackendServerOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// Create adds a new backend server to the collection +func (ops *BackendServerOps) Create(ctx context.Context, server *models.BackendServer) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendServerCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend server collection: %w", err) + } + + // Prepare content for embedding (name + description) + content := server.Name + if server.Description != nil && *server.Description != "" { + content += ". " + *server.Description + } + + // Serialize metadata + metadata, err := serializeServerMetadata(server) + if err != nil { + return fmt.Errorf("failed to serialize server metadata: %w", err) + } + + // Create document + doc := chromem.Document{ + ID: server.ID, + Content: content, + Metadata: metadata, + } + + // If embedding is provided, use it + if len(server.ServerEmbedding) > 0 { + doc.Embedding = server.ServerEmbedding + } + + // Add document to chromem-go collection + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to add server document to chromem-go: %w", err) + } + + // Also add to FTS5 database if available (for keyword filtering) + if ftsDB := ops.db.GetFTSDB(); ftsDB != nil { + if err := ftsDB.UpsertServer(ctx, server); err != nil { + // Log but don't fail - FTS5 is supplementary + logger.Warnf("Failed to upsert server to FTS5: %v", err) + } + } + + logger.Debugf("Created backend server: %s (chromem-go + FTS5)", server.ID) + return nil +} + +// Get retrieves a backend server by ID +func (ops *BackendServerOps) Get(ctx context.Context, serverID string) (*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + return nil, fmt.Errorf("backend server collection not found: %w", err) + } + + // Query by ID with exact match + results, err := collection.Query(ctx, serverID, 1, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to query server: %w", err) + } + + if len(results) == 0 { + return nil, fmt.Errorf("server not found: %s", serverID) + } + + // Deserialize from metadata + server, err := deserializeServerMetadata(results[0].Metadata) + if err != nil { + return nil, fmt.Errorf("failed to deserialize server: %w", err) + } + + return server, nil +} + +// Update updates an existing backend server +func (ops *BackendServerOps) Update(ctx context.Context, server *models.BackendServer) error { + // chromem-go doesn't have an update operation, so we delete and re-create + err := ops.Delete(ctx, server.ID) + if err != nil { + // If server doesn't exist, that's fine + logger.Debugf("Server %s not found for update, will create new", server.ID) + } + + return ops.Create(ctx, server) +} + +// Delete removes a backend server +func (ops *BackendServerOps) Delete(ctx context.Context, serverID string) error { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete + return nil + } + + err = collection.Delete(ctx, nil, nil, serverID) + if err != nil { + return fmt.Errorf("failed to delete server from chromem-go: %w", err) + } + + // Also delete from FTS5 database if available + if ftsDB := ops.db.GetFTSDB(); ftsDB != nil { + if err := ftsDB.DeleteServer(ctx, serverID); err != nil { + // Log but don't fail + logger.Warnf("Failed to delete server from FTS5: %v", err) + } + } + + logger.Debugf("Deleted backend server: %s (chromem-go + FTS5)", serverID) + return nil +} + +// List returns all backend servers +func (ops *BackendServerOps) List(ctx context.Context) ([]*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist yet, return empty list + return []*models.BackendServer{}, nil + } + + // Get count to determine nResults + count := collection.Count() + if count == 0 { + return []*models.BackendServer{}, nil + } + + // Query with a generic term to get all servers + // Using "server" as a generic query that should match all servers + results, err := collection.Query(ctx, "server", count, nil, nil) + if err != nil { + return []*models.BackendServer{}, nil + } + + servers := make([]*models.BackendServer, 0, len(results)) + for _, result := range results { + server, err := deserializeServerMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize server: %v", err) + continue + } + servers = append(servers, server) + } + + return servers, nil +} + +// Search performs semantic search for backend servers +func (ops *BackendServerOps) Search(ctx context.Context, query string, limit int) ([]*models.BackendServer, error) { + collection, err := ops.db.GetCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + return []*models.BackendServer{}, nil + } + + // Get collection count and adjust limit if necessary + count := collection.Count() + if count == 0 { + return []*models.BackendServer{}, nil + } + if limit > count { + limit = count + } + + results, err := collection.Query(ctx, query, limit, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to search servers: %w", err) + } + + servers := make([]*models.BackendServer, 0, len(results)) + for _, result := range results { + server, err := deserializeServerMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize server: %v", err) + continue + } + servers = append(servers, server) + } + + return servers, nil +} + +// Helper functions for metadata serialization + +func serializeServerMetadata(server *models.BackendServer) (map[string]string, error) { + data, err := json.Marshal(server) + if err != nil { + return nil, err + } + return map[string]string{ + "data": string(data), + "type": "backend_server", + }, nil +} + +func deserializeServerMetadata(metadata map[string]string) (*models.BackendServer, error) { + data, ok := metadata["data"] + if !ok { + return nil, fmt.Errorf("missing data field in metadata") + } + + var server models.BackendServer + if err := json.Unmarshal([]byte(data), &server); err != nil { + return nil, err + } + + return &server, nil +} diff --git a/pkg/optimizer/db/backend_server_test.go b/pkg/optimizer/db/backend_server_test.go new file mode 100644 index 0000000000..adc23ae91c --- /dev/null +++ b/pkg/optimizer/db/backend_server_test.go @@ -0,0 +1,424 @@ +package db + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// TestBackendServerOps_Create tests creating a backend server +func TestBackendServerOps_Create(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + description := "A test MCP server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Verify server was created by retrieving it + retrieved, err := ops.Get(ctx, "server-1") + require.NoError(t, err) + assert.Equal(t, "Test Server", retrieved.Name) + assert.Equal(t, "server-1", retrieved.ID) + assert.Equal(t, description, *retrieved.Description) +} + +// TestBackendServerOps_CreateWithEmbedding tests creating server with precomputed embedding +func TestBackendServerOps_CreateWithEmbedding(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + description := "Server with embedding" + embedding := make([]float32, 384) + for i := range embedding { + embedding[i] = 0.5 + } + + server := &models.BackendServer{ + ID: "server-2", + Name: "Embedded Server", + Description: &description, + Group: "default", + ServerEmbedding: embedding, + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Verify server was created + retrieved, err := ops.Get(ctx, "server-2") + require.NoError(t, err) + assert.Equal(t, "Embedded Server", retrieved.Name) +} + +// TestBackendServerOps_Get tests retrieving a backend server +func TestBackendServerOps_Get(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create a server first + description := "GitHub MCP server" + server := &models.BackendServer{ + ID: "github-server", + Name: "GitHub", + Description: &description, + Group: "development", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Test Get + retrieved, err := ops.Get(ctx, "github-server") + require.NoError(t, err) + assert.Equal(t, "github-server", retrieved.ID) + assert.Equal(t, "GitHub", retrieved.Name) + assert.Equal(t, "development", retrieved.Group) +} + +// TestBackendServerOps_Get_NotFound tests retrieving non-existent server +func TestBackendServerOps_Get_NotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to get a non-existent server + _, err := ops.Get(ctx, "non-existent") + assert.Error(t, err) + // Error message could be "server not found" or "collection not found" depending on state + assert.True(t, err != nil, "Should return an error for non-existent server") +} + +// TestBackendServerOps_Update tests updating a backend server +func TestBackendServerOps_Update(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create initial server + description := "Original description" + server := &models.BackendServer{ + ID: "server-1", + Name: "Original Name", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Update the server + updatedDescription := "Updated description" + server.Name = "Updated Name" + server.Description = &updatedDescription + + err = ops.Update(ctx, server) + require.NoError(t, err) + + // Verify update + retrieved, err := ops.Get(ctx, "server-1") + require.NoError(t, err) + assert.Equal(t, "Updated Name", retrieved.Name) + assert.Equal(t, "Updated description", *retrieved.Description) +} + +// TestBackendServerOps_Update_NonExistent tests updating non-existent server +func TestBackendServerOps_Update_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to update non-existent server (should create it) + description := "New server" + server := &models.BackendServer{ + ID: "new-server", + Name: "New Server", + Description: &description, + Group: "default", + } + + err := ops.Update(ctx, server) + require.NoError(t, err) + + // Verify server was created + retrieved, err := ops.Get(ctx, "new-server") + require.NoError(t, err) + assert.Equal(t, "New Server", retrieved.Name) +} + +// TestBackendServerOps_Delete tests deleting a backend server +func TestBackendServerOps_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create a server + description := "Server to delete" + server := &models.BackendServer{ + ID: "delete-me", + Name: "Delete Me", + Description: &description, + Group: "default", + } + + err := ops.Create(ctx, server) + require.NoError(t, err) + + // Delete the server + err = ops.Delete(ctx, "delete-me") + require.NoError(t, err) + + // Verify deletion + _, err = ops.Get(ctx, "delete-me") + assert.Error(t, err, "Should not find deleted server") +} + +// TestBackendServerOps_Delete_NonExistent tests deleting non-existent server +func TestBackendServerOps_Delete_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Try to delete a non-existent server - should not error + err := ops.Delete(ctx, "non-existent") + assert.NoError(t, err) +} + +// TestBackendServerOps_List tests listing all servers +func TestBackendServerOps_List(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create multiple servers + desc1 := "Server 1" + server1 := &models.BackendServer{ + ID: "server-1", + Name: "Server 1", + Description: &desc1, + Group: "group-a", + } + + desc2 := "Server 2" + server2 := &models.BackendServer{ + ID: "server-2", + Name: "Server 2", + Description: &desc2, + Group: "group-b", + } + + desc3 := "Server 3" + server3 := &models.BackendServer{ + ID: "server-3", + Name: "Server 3", + Description: &desc3, + Group: "group-a", + } + + err := ops.Create(ctx, server1) + require.NoError(t, err) + err = ops.Create(ctx, server2) + require.NoError(t, err) + err = ops.Create(ctx, server3) + require.NoError(t, err) + + // List all servers + servers, err := ops.List(ctx) + require.NoError(t, err) + assert.Len(t, servers, 3, "Should have 3 servers") + + // Verify server names + serverNames := make(map[string]bool) + for _, server := range servers { + serverNames[server.Name] = true + } + assert.True(t, serverNames["Server 1"]) + assert.True(t, serverNames["Server 2"]) + assert.True(t, serverNames["Server 3"]) +} + +// TestBackendServerOps_List_Empty tests listing servers on empty database +func TestBackendServerOps_List_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // List empty database + servers, err := ops.List(ctx) + require.NoError(t, err) + assert.Empty(t, servers, "Should return empty list for empty database") +} + +// TestBackendServerOps_Search tests semantic search for servers +func TestBackendServerOps_Search(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Create test servers + desc1 := "GitHub integration server" + server1 := &models.BackendServer{ + ID: "github", + Name: "GitHub Server", + Description: &desc1, + Group: "vcs", + } + + desc2 := "Slack messaging server" + server2 := &models.BackendServer{ + ID: "slack", + Name: "Slack Server", + Description: &desc2, + Group: "messaging", + } + + err := ops.Create(ctx, server1) + require.NoError(t, err) + err = ops.Create(ctx, server2) + require.NoError(t, err) + + // Search for servers + results, err := ops.Search(ctx, "integration", 5) + require.NoError(t, err) + assert.NotEmpty(t, results, "Should find servers") +} + +// TestBackendServerOps_Search_Empty tests search on empty database +func TestBackendServerOps_Search_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendServerOps(db, embeddingFunc) + + // Search empty database + results, err := ops.Search(ctx, "anything", 5) + require.NoError(t, err) + assert.Empty(t, results, "Should return empty results for empty database") +} + +// TestBackendServerOps_MetadataSerialization tests metadata serialization/deserialization +func TestBackendServerOps_MetadataSerialization(t *testing.T) { + t.Parallel() + + description := "Test server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + } + + // Test serialization + metadata, err := serializeServerMetadata(server) + require.NoError(t, err) + assert.Contains(t, metadata, "data") + assert.Equal(t, "backend_server", metadata["type"]) + + // Test deserialization + deserializedServer, err := deserializeServerMetadata(metadata) + require.NoError(t, err) + assert.Equal(t, server.ID, deserializedServer.ID) + assert.Equal(t, server.Name, deserializedServer.Name) + assert.Equal(t, server.Group, deserializedServer.Group) +} + +// TestBackendServerOps_MetadataDeserialization_MissingData tests error handling +func TestBackendServerOps_MetadataDeserialization_MissingData(t *testing.T) { + t.Parallel() + + // Test with missing data field + metadata := map[string]string{ + "type": "backend_server", + } + + _, err := deserializeServerMetadata(metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing data field") +} + +// TestBackendServerOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling +func TestBackendServerOps_MetadataDeserialization_InvalidJSON(t *testing.T) { + t.Parallel() + + // Test with invalid JSON + metadata := map[string]string{ + "data": "invalid json {", + "type": "backend_server", + } + + _, err := deserializeServerMetadata(metadata) + assert.Error(t, err) +} diff --git a/pkg/optimizer/db/backend_tool.go b/pkg/optimizer/db/backend_tool.go new file mode 100644 index 0000000000..909779edb8 --- /dev/null +++ b/pkg/optimizer/db/backend_tool.go @@ -0,0 +1,310 @@ +package db + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// BackendToolOps provides operations for backend tools in chromem-go +type BackendToolOps struct { + db *DB + embeddingFunc chromem.EmbeddingFunc +} + +// NewBackendToolOps creates a new BackendToolOps instance +func NewBackendToolOps(db *DB, embeddingFunc chromem.EmbeddingFunc) *BackendToolOps { + return &BackendToolOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// Create adds a new backend tool to the collection +func (ops *BackendToolOps) Create(ctx context.Context, tool *models.BackendTool, serverName string) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend tool collection: %w", err) + } + + // Prepare content for embedding (name + description + input schema summary) + content := tool.ToolName + if tool.Description != nil && *tool.Description != "" { + content += ". " + *tool.Description + } + + // Serialize metadata + metadata, err := serializeToolMetadata(tool) + if err != nil { + return fmt.Errorf("failed to serialize tool metadata: %w", err) + } + + // Create document + doc := chromem.Document{ + ID: tool.ID, + Content: content, + Metadata: metadata, + } + + // If embedding is provided, use it + if len(tool.ToolEmbedding) > 0 { + doc.Embedding = tool.ToolEmbedding + } + + // Add document to chromem-go collection + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to add tool document to chromem-go: %w", err) + } + + // Also add to FTS5 database if available (for BM25 search) + if ops.db.fts != nil { + if err := ops.db.fts.UpsertToolMeta(ctx, tool, serverName); err != nil { + // Log but don't fail - FTS5 is supplementary + logger.Warnf("Failed to upsert tool to FTS5: %v", err) + } + } + + logger.Debugf("Created backend tool: %s (chromem-go + FTS5)", tool.ID) + return nil +} + +// Get retrieves a backend tool by ID +func (ops *BackendToolOps) Get(ctx context.Context, toolID string) (*models.BackendTool, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + return nil, fmt.Errorf("backend tool collection not found: %w", err) + } + + // Query by ID with exact match + results, err := collection.Query(ctx, toolID, 1, nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to query tool: %w", err) + } + + if len(results) == 0 { + return nil, fmt.Errorf("tool not found: %s", toolID) + } + + // Deserialize from metadata + tool, err := deserializeToolMetadata(results[0].Metadata) + if err != nil { + return nil, fmt.Errorf("failed to deserialize tool: %w", err) + } + + return tool, nil +} + +// Update updates an existing backend tool in chromem-go +// Note: This only updates chromem-go, not FTS5. Use Create to update both. +func (ops *BackendToolOps) Update(ctx context.Context, tool *models.BackendTool) error { + collection, err := ops.db.GetOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend tool collection: %w", err) + } + + // Prepare content for embedding + content := tool.ToolName + if tool.Description != nil && *tool.Description != "" { + content += ". " + *tool.Description + } + + // Serialize metadata + metadata, err := serializeToolMetadata(tool) + if err != nil { + return fmt.Errorf("failed to serialize tool metadata: %w", err) + } + + // Delete existing document + _ = collection.Delete(ctx, nil, nil, tool.ID) // Ignore error if doesn't exist + + // Create updated document + doc := chromem.Document{ + ID: tool.ID, + Content: content, + Metadata: metadata, + } + + if len(tool.ToolEmbedding) > 0 { + doc.Embedding = tool.ToolEmbedding + } + + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to update tool document: %w", err) + } + + logger.Debugf("Updated backend tool: %s", tool.ID) + return nil +} + +// Delete removes a backend tool +func (ops *BackendToolOps) Delete(ctx context.Context, toolID string) error { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete + return nil + } + + err = collection.Delete(ctx, nil, nil, toolID) + if err != nil { + return fmt.Errorf("failed to delete tool: %w", err) + } + + logger.Debugf("Deleted backend tool: %s", toolID) + return nil +} + +// DeleteByServer removes all tools for a given server from both chromem-go and FTS5 +func (ops *BackendToolOps) DeleteByServer(ctx context.Context, serverID string) error { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete in chromem-go + logger.Debug("Backend tool collection not found, skipping chromem-go deletion") + } else { + // Query all tools for this server + tools, err := ops.ListByServer(ctx, serverID) + if err != nil { + return fmt.Errorf("failed to list tools for server: %w", err) + } + + // Delete each tool from chromem-go + for _, tool := range tools { + if err := collection.Delete(ctx, nil, nil, tool.ID); err != nil { + logger.Warnf("Failed to delete tool %s from chromem-go: %v", tool.ID, err) + } + } + + logger.Debugf("Deleted %d tools from chromem-go for server: %s", len(tools), serverID) + } + + // Also delete from FTS5 database if available + if ops.db.fts != nil { + if err := ops.db.fts.DeleteToolsByServer(ctx, serverID); err != nil { + logger.Warnf("Failed to delete tools from FTS5 for server %s: %v", serverID, err) + } else { + logger.Debugf("Deleted tools from FTS5 for server: %s", serverID) + } + } + + return nil +} + +// ListByServer returns all tools for a given server +func (ops *BackendToolOps) ListByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist yet, return empty list + return []*models.BackendTool{}, nil + } + + // Get count to determine nResults + count := collection.Count() + if count == 0 { + return []*models.BackendTool{}, nil + } + + // Query with a generic term and metadata filter + // Using "tool" as a generic query that should match all tools + results, err := collection.Query(ctx, "tool", count, map[string]string{"server_id": serverID}, nil) + if err != nil { + // If no tools match, return empty list + return []*models.BackendTool{}, nil + } + + tools := make([]*models.BackendTool, 0, len(results)) + for _, result := range results { + tool, err := deserializeToolMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize tool: %v", err) + continue + } + tools = append(tools, tool) + } + + return tools, nil +} + +// Search performs semantic search for backend tools +func (ops *BackendToolOps) Search( + ctx context.Context, + query string, + limit int, + serverID *string, +) ([]*models.BackendToolWithMetadata, error) { + collection, err := ops.db.GetCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + return []*models.BackendToolWithMetadata{}, nil + } + + // Get collection count and adjust limit if necessary + count := collection.Count() + if count == 0 { + return []*models.BackendToolWithMetadata{}, nil + } + if limit > count { + limit = count + } + + // Build metadata filter if server ID is provided + var metadataFilter map[string]string + if serverID != nil { + metadataFilter = map[string]string{"server_id": *serverID} + } + + results, err := collection.Query(ctx, query, limit, metadataFilter, nil) + if err != nil { + return nil, fmt.Errorf("failed to search tools: %w", err) + } + + tools := make([]*models.BackendToolWithMetadata, 0, len(results)) + for _, result := range results { + tool, err := deserializeToolMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize tool: %v", err) + continue + } + + // Add similarity score + toolWithMeta := &models.BackendToolWithMetadata{ + BackendTool: *tool, + Similarity: result.Similarity, + } + tools = append(tools, toolWithMeta) + } + + return tools, nil +} + +// Helper functions for metadata serialization + +func serializeToolMetadata(tool *models.BackendTool) (map[string]string, error) { + data, err := json.Marshal(tool) + if err != nil { + return nil, err + } + return map[string]string{ + "data": string(data), + "type": "backend_tool", + "server_id": tool.MCPServerID, + }, nil +} + +func deserializeToolMetadata(metadata map[string]string) (*models.BackendTool, error) { + data, ok := metadata["data"] + if !ok { + return nil, fmt.Errorf("missing data field in metadata") + } + + var tool models.BackendTool + if err := json.Unmarshal([]byte(data), &tool); err != nil { + return nil, err + } + + return &tool, nil +} diff --git a/pkg/optimizer/db/backend_tool_test.go b/pkg/optimizer/db/backend_tool_test.go new file mode 100644 index 0000000000..557e5ca5f5 --- /dev/null +++ b/pkg/optimizer/db/backend_tool_test.go @@ -0,0 +1,579 @@ +package db + +import ( + "context" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// createTestDB creates a test database with placeholder embeddings +func createTestDB(t *testing.T) *DB { + t.Helper() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + + return db +} + +// createTestEmbeddingFunc creates a test embedding function using placeholder embeddings +func createTestEmbeddingFunc(t *testing.T) func(ctx context.Context, text string) ([]float32, error) { + t.Helper() + + // Create placeholder embedding manager + config := &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + } + + manager, err := embeddings.NewManager(config) + require.NoError(t, err) + t.Cleanup(func() { _ = manager.Close() }) + + return func(_ context.Context, text string) ([]float32, error) { + results, err := manager.GenerateEmbedding([]string{text}) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, assert.AnError + } + return results[0], nil + } +} + +// TestBackendToolOps_Create tests creating a backend tool +func TestBackendToolOps_Create(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + description := "Get current weather information" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &description, + InputSchema: []byte(`{"type":"object","properties":{"location":{"type":"string"}}}`), + TokenCount: 100, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Verify tool was created by retrieving it + retrieved, err := ops.Get(ctx, "tool-1") + require.NoError(t, err) + assert.Equal(t, "get_weather", retrieved.ToolName) + assert.Equal(t, "server-1", retrieved.MCPServerID) + assert.Equal(t, description, *retrieved.Description) +} + +// TestBackendToolOps_CreateWithPrecomputedEmbedding tests creating tool with existing embedding +func TestBackendToolOps_CreateWithPrecomputedEmbedding(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + description := "Search the web" + // Generate a precomputed embedding + precomputedEmbedding := make([]float32, 384) + for i := range precomputedEmbedding { + precomputedEmbedding[i] = 0.1 + } + + tool := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "search_web", + Description: &description, + InputSchema: []byte(`{}`), + ToolEmbedding: precomputedEmbedding, + TokenCount: 50, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Verify tool was created + retrieved, err := ops.Get(ctx, "tool-2") + require.NoError(t, err) + assert.Equal(t, "search_web", retrieved.ToolName) +} + +// TestBackendToolOps_Get tests retrieving a backend tool +func TestBackendToolOps_Get(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create a tool first + description := "Send an email" + tool := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-1", + ToolName: "send_email", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 75, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Test Get + retrieved, err := ops.Get(ctx, "tool-3") + require.NoError(t, err) + assert.Equal(t, "tool-3", retrieved.ID) + assert.Equal(t, "send_email", retrieved.ToolName) +} + +// TestBackendToolOps_Get_NotFound tests retrieving non-existent tool +func TestBackendToolOps_Get_NotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Try to get a non-existent tool + _, err := ops.Get(ctx, "non-existent") + assert.Error(t, err) +} + +// TestBackendToolOps_Update tests updating a backend tool +func TestBackendToolOps_Update(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create initial tool + description := "Original description" + tool := &models.BackendTool{ + ID: "tool-4", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Update the tool + const updatedDescription = "Updated description" + updatedDescriptionCopy := updatedDescription + tool.Description = &updatedDescriptionCopy + tool.TokenCount = 75 + + err = ops.Update(ctx, tool) + require.NoError(t, err) + + // Verify update + retrieved, err := ops.Get(ctx, "tool-4") + require.NoError(t, err) + assert.Equal(t, "Updated description", *retrieved.Description) +} + +// TestBackendToolOps_Delete tests deleting a backend tool +func TestBackendToolOps_Delete(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create a tool + description := "Tool to delete" + tool := &models.BackendTool{ + ID: "tool-5", + MCPServerID: "server-1", + ToolName: "delete_me", + Description: &description, + InputSchema: []byte(`{}`), + TokenCount: 25, + } + + err := ops.Create(ctx, tool, "Test Server") + require.NoError(t, err) + + // Delete the tool + err = ops.Delete(ctx, "tool-5") + require.NoError(t, err) + + // Verify deletion + _, err = ops.Get(ctx, "tool-5") + assert.Error(t, err, "Should not find deleted tool") +} + +// TestBackendToolOps_Delete_NonExistent tests deleting non-existent tool +func TestBackendToolOps_Delete_NonExistent(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Try to delete a non-existent tool - should not error + err := ops.Delete(ctx, "non-existent") + // Delete may or may not error depending on implementation + // Just ensure it doesn't panic + _ = err +} + +// TestBackendToolOps_ListByServer tests listing tools for a server +func TestBackendToolOps_ListByServer(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create multiple tools for different servers + desc1 := "Tool 1" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "tool_1", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 10, + } + + desc2 := "Tool 2" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "tool_2", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 20, + } + + desc3 := "Tool 3" + tool3 := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-2", + ToolName: "tool_3", + Description: &desc3, + InputSchema: []byte(`{}`), + TokenCount: 30, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool3, "Server 2") + require.NoError(t, err) + + // List tools for server-1 + tools, err := ops.ListByServer(ctx, "server-1") + require.NoError(t, err) + assert.Len(t, tools, 2, "Should have 2 tools for server-1") + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range tools { + toolNames[tool.ToolName] = true + } + assert.True(t, toolNames["tool_1"]) + assert.True(t, toolNames["tool_2"]) + + // List tools for server-2 + tools, err = ops.ListByServer(ctx, "server-2") + require.NoError(t, err) + assert.Len(t, tools, 1, "Should have 1 tool for server-2") + assert.Equal(t, "tool_3", tools[0].ToolName) +} + +// TestBackendToolOps_ListByServer_Empty tests listing tools for server with no tools +func TestBackendToolOps_ListByServer_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // List tools for non-existent server + tools, err := ops.ListByServer(ctx, "non-existent-server") + require.NoError(t, err) + assert.Empty(t, tools, "Should return empty list for server with no tools") +} + +// TestBackendToolOps_DeleteByServer tests deleting all tools for a server +func TestBackendToolOps_DeleteByServer(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create tools for two servers + desc1 := "Tool 1" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "tool_1", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 10, + } + + desc2 := "Tool 2" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "tool_2", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 20, + } + + desc3 := "Tool 3" + tool3 := &models.BackendTool{ + ID: "tool-3", + MCPServerID: "server-2", + ToolName: "tool_3", + Description: &desc3, + InputSchema: []byte(`{}`), + TokenCount: 30, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool3, "Server 2") + require.NoError(t, err) + + // Delete all tools for server-1 + err = ops.DeleteByServer(ctx, "server-1") + require.NoError(t, err) + + // Verify server-1 tools are deleted + tools, err := ops.ListByServer(ctx, "server-1") + require.NoError(t, err) + assert.Empty(t, tools, "All server-1 tools should be deleted") + + // Verify server-2 tools are still present + tools, err = ops.ListByServer(ctx, "server-2") + require.NoError(t, err) + assert.Len(t, tools, 1, "Server-2 tools should remain") +} + +// TestBackendToolOps_Search tests semantic search for tools +func TestBackendToolOps_Search(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create test tools + desc1 := "Get current weather conditions" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + desc2 := "Send email message" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "send_email", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 40, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 1") + require.NoError(t, err) + + // Search for tools + results, err := ops.Search(ctx, "weather information", 5, nil) + require.NoError(t, err) + assert.NotEmpty(t, results, "Should find tools") + + // With placeholder embeddings, we just verify we get results + // Semantic similarity isn't guaranteed with hash-based embeddings + assert.Len(t, results, 2, "Should return both tools") +} + +// TestBackendToolOps_Search_WithServerFilter tests search with server ID filter +func TestBackendToolOps_Search_WithServerFilter(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Create tools for different servers + desc1 := "Weather tool" + tool1 := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &desc1, + InputSchema: []byte(`{}`), + TokenCount: 50, + } + + desc2 := "Email tool" + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-2", + ToolName: "send_email", + Description: &desc2, + InputSchema: []byte(`{}`), + TokenCount: 40, + } + + err := ops.Create(ctx, tool1, "Server 1") + require.NoError(t, err) + err = ops.Create(ctx, tool2, "Server 2") + require.NoError(t, err) + + // Search with server filter + serverID := "server-1" + results, err := ops.Search(ctx, "tool", 5, &serverID) + require.NoError(t, err) + assert.Len(t, results, 1, "Should only return tools from server-1") + assert.Equal(t, "server-1", results[0].MCPServerID) +} + +// TestBackendToolOps_Search_Empty tests search on empty database +func TestBackendToolOps_Search_Empty(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDB(t) + defer func() { _ = db.Close() }() + + embeddingFunc := createTestEmbeddingFunc(t) + ops := NewBackendToolOps(db, embeddingFunc) + + // Search empty database + results, err := ops.Search(ctx, "anything", 5, nil) + require.NoError(t, err) + assert.Empty(t, results, "Should return empty results for empty database") +} + +// TestBackendToolOps_MetadataSerialization tests metadata serialization/deserialization +func TestBackendToolOps_MetadataSerialization(t *testing.T) { + t.Parallel() + + description := "Test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &description, + InputSchema: []byte(`{"type":"object"}`), + TokenCount: 100, + } + + // Test serialization + metadata, err := serializeToolMetadata(tool) + require.NoError(t, err) + assert.Contains(t, metadata, "data") + assert.Equal(t, "backend_tool", metadata["type"]) + assert.Equal(t, "server-1", metadata["server_id"]) + + // Test deserialization + deserializedTool, err := deserializeToolMetadata(metadata) + require.NoError(t, err) + assert.Equal(t, tool.ID, deserializedTool.ID) + assert.Equal(t, tool.ToolName, deserializedTool.ToolName) + assert.Equal(t, tool.MCPServerID, deserializedTool.MCPServerID) +} + +// TestBackendToolOps_MetadataDeserialization_MissingData tests error handling +func TestBackendToolOps_MetadataDeserialization_MissingData(t *testing.T) { + t.Parallel() + + // Test with missing data field + metadata := map[string]string{ + "type": "backend_tool", + } + + _, err := deserializeToolMetadata(metadata) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing data field") +} + +// TestBackendToolOps_MetadataDeserialization_InvalidJSON tests invalid JSON handling +func TestBackendToolOps_MetadataDeserialization_InvalidJSON(t *testing.T) { + t.Parallel() + + // Test with invalid JSON + metadata := map[string]string{ + "data": "invalid json {", + "type": "backend_tool", + } + + _, err := deserializeToolMetadata(metadata) + assert.Error(t, err) +} diff --git a/pkg/optimizer/db/db.go b/pkg/optimizer/db/db.go new file mode 100644 index 0000000000..f7e7df5bb8 --- /dev/null +++ b/pkg/optimizer/db/db.go @@ -0,0 +1,182 @@ +package db + +import ( + "context" + "fmt" + "sync" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// Config holds database configuration +// +// The optimizer database is designed to be ephemeral - it's rebuilt from scratch +// on each startup by ingesting MCP backends. Persistence is optional and primarily +// useful for development/debugging to avoid re-generating embeddings. +type Config struct { + // PersistPath is the optional path for chromem-go persistence. + // If empty, chromem-go will be in-memory only (recommended for production). + PersistPath string + + // FTSDBPath is the path for SQLite FTS5 database for BM25 search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // FTS5 is always enabled for hybrid search. + FTSDBPath string +} + +// DB represents the hybrid database (chromem-go + SQLite FTS5) for optimizer data +type DB struct { + config *Config + chromem *chromem.DB // Vector/semantic search + fts *FTSDatabase // BM25 full-text search (optional) + mu sync.RWMutex +} + +// Collection names +// +// Terminology: We use "backend_servers" and "backend_tools" to be explicit about +// tracking MCP server metadata. While vMCP uses "Backend" for the workload concept, +// the optimizer focuses on the MCP server component for semantic search and tool discovery. +// This naming convention provides clarity across the database layer. +const ( + BackendServerCollection = "backend_servers" + BackendToolCollection = "backend_tools" +) + +// NewDB creates a new chromem-go database with FTS5 for hybrid search +func NewDB(config *Config) (*DB, error) { + var chromemDB *chromem.DB + var err error + + if config.PersistPath != "" { + logger.Infof("Creating chromem-go database with persistence at: %s", config.PersistPath) + chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + return nil, fmt.Errorf("failed to create persistent database: %w", err) + } + } else { + logger.Info("Creating in-memory chromem-go database") + chromemDB = chromem.NewDB() + } + + db := &DB{ + config: config, + chromem: chromemDB, + } + + // Set default FTS5 path if not provided + ftsPath := config.FTSDBPath + if ftsPath == "" { + if config.PersistPath != "" { + // Persistent mode: store FTS5 alongside chromem-go + ftsPath = config.PersistPath + "/fts.db" + } else { + // In-memory mode: use SQLite in-memory database + ftsPath = ":memory:" + } + } + + // Initialize FTS5 database for BM25 text search (always enabled) + logger.Infof("Initializing FTS5 database for hybrid search at: %s", ftsPath) + ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ftsPath}) + if err != nil { + return nil, fmt.Errorf("failed to create FTS5 database: %w", err) + } + db.fts = ftsDB + logger.Info("Hybrid search enabled (chromem-go + FTS5)") + + logger.Info("Optimizer database initialized successfully") + return db, nil +} + +// GetOrCreateCollection gets an existing collection or creates a new one +func (db *DB) GetOrCreateCollection( + _ context.Context, + name string, + embeddingFunc chromem.EmbeddingFunc, +) (*chromem.Collection, error) { + db.mu.Lock() + defer db.mu.Unlock() + + // Try to get existing collection first + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection != nil { + return collection, nil + } + + // Create new collection if it doesn't exist + collection, err := db.chromem.CreateCollection(name, nil, embeddingFunc) + if err != nil { + return nil, fmt.Errorf("failed to create collection %s: %w", name, err) + } + + logger.Debugf("Created new collection: %s", name) + return collection, nil +} + +// GetCollection gets an existing collection +func (db *DB) GetCollection(name string, embeddingFunc chromem.EmbeddingFunc) (*chromem.Collection, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection == nil { + return nil, fmt.Errorf("collection not found: %s", name) + } + return collection, nil +} + +// DeleteCollection deletes a collection +func (db *DB) DeleteCollection(name string) { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(name) + logger.Debugf("Deleted collection: %s", name) +} + +// Close closes both databases +func (db *DB) Close() error { + logger.Info("Closing optimizer databases") + // chromem-go doesn't need explicit close, but FTS5 does + if db.fts != nil { + if err := db.fts.Close(); err != nil { + return fmt.Errorf("failed to close FTS database: %w", err) + } + } + return nil +} + +// GetChromemDB returns the underlying chromem.DB instance +func (db *DB) GetChromemDB() *chromem.DB { + return db.chromem +} + +// GetFTSDB returns the FTS database (may be nil if FTS is disabled) +func (db *DB) GetFTSDB() *FTSDatabase { + return db.fts +} + +// Reset clears all collections and FTS tables (useful for testing) +func (db *DB) Reset() { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendServerCollection) + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendToolCollection) + + // Clear FTS5 tables if available + if db.fts != nil { + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_tools_fts") + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_servers_fts") + } + + logger.Debug("Reset all collections and FTS tables") +} diff --git a/pkg/optimizer/db/fts.go b/pkg/optimizer/db/fts.go new file mode 100644 index 0000000000..8dde0b2aa3 --- /dev/null +++ b/pkg/optimizer/db/fts.go @@ -0,0 +1,341 @@ +package db + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "strings" + "sync" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +//go:embed schema_fts.sql +var schemaFTS string + +// FTSConfig holds FTS5 database configuration +type FTSConfig struct { + // DBPath is the path to the SQLite database file + // If empty, uses ":memory:" for in-memory database + DBPath string +} + +// FTSDatabase handles FTS5 (BM25) search operations +type FTSDatabase struct { + config *FTSConfig + db *sql.DB + mu sync.RWMutex +} + +// NewFTSDatabase creates a new FTS5 database for BM25 search +func NewFTSDatabase(config *FTSConfig) (*FTSDatabase, error) { + dbPath := config.DBPath + if dbPath == "" { + dbPath = ":memory:" + } + + // Open with modernc.org/sqlite (pure Go) + sqlDB, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open FTS database: %w", err) + } + + // Set pragmas for performance + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", + } + + for _, pragma := range pragmas { + if _, err := sqlDB.Exec(pragma); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to set pragma: %w", err) + } + } + + ftsDB := &FTSDatabase{ + config: config, + db: sqlDB, + } + + // Initialize schema + if err := ftsDB.initializeSchema(); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to initialize FTS schema: %w", err) + } + + logger.Infof("FTS5 database initialized successfully at: %s", dbPath) + return ftsDB, nil +} + +// initializeSchema creates the FTS5 tables and triggers +// +// Note: We execute the schema directly rather than using a migration framework +// because the FTS database is ephemeral (destroyed on shutdown, recreated on startup). +// Migrations are only needed when you need to preserve data across schema changes. +func (fts *FTSDatabase) initializeSchema() error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.Exec(schemaFTS) + if err != nil { + return fmt.Errorf("failed to execute schema: %w", err) + } + + logger.Debug("FTS5 schema initialized") + return nil +} + +// UpsertServer inserts or updates a server in the FTS database +func (fts *FTSDatabase) UpsertServer( + ctx context.Context, + server *models.BackendServer, +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + query := ` + INSERT INTO backend_servers_fts (id, name, description, server_group, last_updated, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + description = excluded.description, + server_group = excluded.server_group, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + server.ID, + server.Name, + server.Description, + server.Group, + server.LastUpdated, + server.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert server in FTS: %w", err) + } + + logger.Debugf("Upserted server in FTS: %s", server.ID) + return nil +} + +// UpsertToolMeta inserts or updates a tool in the FTS database +func (fts *FTSDatabase) UpsertToolMeta( + ctx context.Context, + tool *models.BackendTool, + _ string, // serverName - unused, keeping for interface compatibility +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Convert input schema to JSON string + var schemaStr *string + if len(tool.InputSchema) > 0 { + str := string(tool.InputSchema) + schemaStr = &str + } + + query := ` + INSERT INTO backend_tools_fts ( + id, mcpserver_id, tool_name, tool_description, + input_schema, token_count, last_updated, created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + mcpserver_id = excluded.mcpserver_id, + tool_name = excluded.tool_name, + tool_description = excluded.tool_description, + input_schema = excluded.input_schema, + token_count = excluded.token_count, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + tool.ID, + tool.MCPServerID, + tool.ToolName, + tool.Description, + schemaStr, + tool.TokenCount, + tool.LastUpdated, + tool.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert tool in FTS: %w", err) + } + + logger.Debugf("Upserted tool in FTS: %s", tool.ToolName) + return nil +} + +// DeleteServer removes a server and its tools from FTS database +func (fts *FTSDatabase) DeleteServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Foreign key cascade will delete related tools + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_servers_fts WHERE id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete server from FTS: %w", err) + } + + logger.Debugf("Deleted server from FTS: %s", serverID) + return nil +} + +// DeleteToolsByServer removes all tools for a server from FTS database +func (fts *FTSDatabase) DeleteToolsByServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + result, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE mcpserver_id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete tools from FTS: %w", err) + } + + count, _ := result.RowsAffected() + logger.Debugf("Deleted %d tools from FTS for server: %s", count, serverID) + return nil +} + +// DeleteTool removes a tool from FTS database +func (fts *FTSDatabase) DeleteTool(ctx context.Context, toolID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE id = ?", toolID) + if err != nil { + return fmt.Errorf("failed to delete tool from FTS: %w", err) + } + + logger.Debugf("Deleted tool from FTS: %s", toolID) + return nil +} + +// SearchBM25 performs BM25 full-text search on tools +func (fts *FTSDatabase) SearchBM25( + ctx context.Context, + query string, + limit int, + serverID *string, +) ([]*models.BackendToolWithMetadata, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + // Sanitize FTS5 query + sanitizedQuery := sanitizeFTS5Query(query) + if sanitizedQuery == "" { + return []*models.BackendToolWithMetadata{}, nil + } + + // Build query with optional server filter + sqlQuery := ` + SELECT + t.id, + t.mcpserver_id, + t.tool_name, + t.tool_description, + t.input_schema, + t.token_count, + t.last_updated, + t.created_at, + fts.rank + FROM backend_tool_fts_index fts + JOIN backend_tools_fts t ON fts.tool_id = t.id + WHERE backend_tool_fts_index MATCH ? + ` + + args := []interface{}{sanitizedQuery} + + if serverID != nil { + sqlQuery += " AND t.mcpserver_id = ?" + args = append(args, *serverID) + } + + sqlQuery += " ORDER BY rank LIMIT ?" + args = append(args, limit) + + rows, err := fts.db.QueryContext(ctx, sqlQuery, args...) + if err != nil { + return nil, fmt.Errorf("failed to search tools: %w", err) + } + defer func() { _ = rows.Close() }() + + var results []*models.BackendToolWithMetadata + for rows.Next() { + var tool models.BackendTool + var schemaStr sql.NullString + var rank float32 + + err := rows.Scan( + &tool.ID, + &tool.MCPServerID, + &tool.ToolName, + &tool.Description, + &schemaStr, + &tool.TokenCount, + &tool.LastUpdated, + &tool.CreatedAt, + &rank, + ) + if err != nil { + logger.Warnf("Failed to scan tool row: %v", err) + continue + } + + if schemaStr.Valid { + tool.InputSchema = []byte(schemaStr.String) + } + + // Convert BM25 rank to similarity score (higher is better) + // FTS5 rank is negative, so we negate and normalize + similarity := float32(1.0 / (1.0 - float64(rank))) + + results = append(results, &models.BackendToolWithMetadata{ + BackendTool: tool, + Similarity: similarity, + }) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating tool rows: %w", err) + } + + logger.Debugf("BM25 search found %d tools for query: %s", len(results), query) + return results, nil +} + +// Close closes the FTS database connection +func (fts *FTSDatabase) Close() error { + return fts.db.Close() +} + +// sanitizeFTS5Query escapes special characters in FTS5 queries +// FTS5 uses: " * ( ) AND OR NOT +func sanitizeFTS5Query(query string) string { + // Remove or escape special FTS5 characters + replacer := strings.NewReplacer( + `"`, `""`, // Escape quotes + `*`, ` `, // Remove wildcards + `(`, ` `, // Remove parentheses + `)`, ` `, + ) + + sanitized := replacer.Replace(query) + + // Remove multiple spaces + sanitized = strings.Join(strings.Fields(sanitized), " ") + + return strings.TrimSpace(sanitized) +} diff --git a/pkg/optimizer/db/hybrid.go b/pkg/optimizer/db/hybrid.go new file mode 100644 index 0000000000..04bbc3fd82 --- /dev/null +++ b/pkg/optimizer/db/hybrid.go @@ -0,0 +1,167 @@ +package db + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// HybridSearchConfig configures hybrid search behavior +type HybridSearchConfig struct { + // SemanticRatio controls the mix of semantic vs BM25 results (0.0 = all BM25, 1.0 = all semantic) + // Default: 0.7 (70% semantic, 30% BM25) + SemanticRatio float64 + + // Limit is the total number of results to return + Limit int + + // ServerID optionally filters results to a specific server + ServerID *string +} + +// DefaultHybridConfig returns sensible defaults for hybrid search +func DefaultHybridConfig() *HybridSearchConfig { + return &HybridSearchConfig{ + SemanticRatio: 0.7, + Limit: 10, + } +} + +// SearchHybrid performs hybrid search combining semantic (chromem-go) and BM25 (FTS5) results +// This matches the Python mcp-optimizer's hybrid search implementation +func (ops *BackendToolOps) SearchHybrid( + ctx context.Context, + queryText string, + config *HybridSearchConfig, +) ([]*models.BackendToolWithMetadata, error) { + if config == nil { + config = DefaultHybridConfig() + } + + // Calculate limits for each search method + semanticLimit := max(1, int(float64(config.Limit)*config.SemanticRatio)) + bm25Limit := max(1, config.Limit-semanticLimit) + + logger.Debugf( + "Hybrid search: semantic_limit=%d, bm25_limit=%d, ratio=%.2f", + semanticLimit, bm25Limit, config.SemanticRatio, + ) + + // Execute both searches in parallel + type searchResult struct { + results []*models.BackendToolWithMetadata + err error + } + + semanticCh := make(chan searchResult, 1) + bm25Ch := make(chan searchResult, 1) + + // Semantic search + go func() { + results, err := ops.Search(ctx, queryText, semanticLimit, config.ServerID) + semanticCh <- searchResult{results, err} + }() + + // BM25 search + go func() { + results, err := ops.db.fts.SearchBM25(ctx, queryText, bm25Limit, config.ServerID) + bm25Ch <- searchResult{results, err} + }() + + // Collect results + var semanticResults, bm25Results []*models.BackendToolWithMetadata + var errs []error + + // Wait for semantic results + semanticRes := <-semanticCh + if semanticRes.err != nil { + logger.Warnf("Semantic search failed: %v", semanticRes.err) + errs = append(errs, semanticRes.err) + } else { + semanticResults = semanticRes.results + } + + // Wait for BM25 results + bm25Res := <-bm25Ch + if bm25Res.err != nil { + logger.Warnf("BM25 search failed: %v", bm25Res.err) + errs = append(errs, bm25Res.err) + } else { + bm25Results = bm25Res.results + } + + // If both failed, return error + if len(errs) == 2 { + return nil, fmt.Errorf("both search methods failed: semantic=%v, bm25=%v", errs[0], errs[1]) + } + + // Combine and deduplicate results + combined := combineAndDeduplicateResults(semanticResults, bm25Results, config.Limit) + + logger.Infof( + "Hybrid search completed: semantic=%d, bm25=%d, combined=%d (requested=%d)", + len(semanticResults), len(bm25Results), len(combined), config.Limit, + ) + + return combined, nil +} + +// combineAndDeduplicateResults merges semantic and BM25 results, removing duplicates +// Keeps the result with the higher similarity score for duplicates +func combineAndDeduplicateResults( + semantic, bm25 []*models.BackendToolWithMetadata, + limit int, +) []*models.BackendToolWithMetadata { + // Use a map to deduplicate by tool ID + seen := make(map[string]*models.BackendToolWithMetadata) + + // Add semantic results first (they typically have higher quality) + for _, result := range semantic { + seen[result.ID] = result + } + + // Add BM25 results, only if not seen or if similarity is higher + for _, result := range bm25 { + if existing, exists := seen[result.ID]; exists { + // Keep the one with higher similarity + if result.Similarity > existing.Similarity { + seen[result.ID] = result + } + } else { + seen[result.ID] = result + } + } + + // Convert map to slice + combined := make([]*models.BackendToolWithMetadata, 0, len(seen)) + for _, result := range seen { + combined = append(combined, result) + } + + // Sort by similarity (descending) and limit + sortedResults := sortBySimilarity(combined) + if len(sortedResults) > limit { + sortedResults = sortedResults[:limit] + } + + return sortedResults +} + +// sortBySimilarity sorts results by similarity score in descending order +func sortBySimilarity(results []*models.BackendToolWithMetadata) []*models.BackendToolWithMetadata { + // Simple bubble sort (fine for small result sets) + sorted := make([]*models.BackendToolWithMetadata, len(results)) + copy(sorted, results) + + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[j].Similarity > sorted[i].Similarity { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + return sorted +} diff --git a/pkg/optimizer/db/schema_fts.sql b/pkg/optimizer/db/schema_fts.sql new file mode 100644 index 0000000000..101dbea7d7 --- /dev/null +++ b/pkg/optimizer/db/schema_fts.sql @@ -0,0 +1,120 @@ +-- FTS5 schema for BM25 full-text search +-- Complements chromem-go (which handles vector/semantic search) +-- +-- This schema only contains: +-- 1. Metadata tables for tool/server information +-- 2. FTS5 virtual tables for BM25 keyword search +-- +-- Note: chromem-go handles embeddings separately in memory/persistent storage + +-- Backend servers metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_servers_fts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + server_group TEXT NOT NULL DEFAULT 'default', + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_backend_servers_fts_group ON backend_servers_fts(server_group); + +-- Backend tools metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_tools_fts ( + id TEXT PRIMARY KEY, + mcpserver_id TEXT NOT NULL, + tool_name TEXT NOT NULL, + tool_description TEXT, + input_schema TEXT, -- JSON string + token_count INTEGER NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (mcpserver_id) REFERENCES backend_servers_fts(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_server ON backend_tools_fts(mcpserver_id); +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_name ON backend_tools_fts(tool_name); + +-- FTS5 virtual table for backend tools +-- Uses Porter stemming for better keyword matching +-- Indexes: server name, tool name, and tool description +CREATE VIRTUAL TABLE IF NOT EXISTS backend_tool_fts_index +USING fts5( + tool_id UNINDEXED, + mcp_server_name, + tool_name, + tool_description, + tokenize='porter', + content='backend_tools_fts', + content_rowid='rowid' +); + +-- Triggers to keep FTS5 index in sync with backend_tools_fts table +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ai AFTER INSERT ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) + SELECT + rowid, + new.id, + (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id), + new.tool_name, + COALESCE(new.tool_description, '') + FROM backend_tools_fts + WHERE id = new.id; +END; + +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ad AFTER DELETE ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + backend_tool_fts_index, + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) VALUES ( + 'delete', + old.rowid, + old.id, + NULL, + NULL, + NULL + ); +END; + +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_au AFTER UPDATE ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + backend_tool_fts_index, + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) VALUES ( + 'delete', + old.rowid, + old.id, + NULL, + NULL, + NULL + ); + INSERT INTO backend_tool_fts_index( + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) + SELECT + rowid, + new.id, + (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id), + new.tool_name, + COALESCE(new.tool_description, '') + FROM backend_tools_fts + WHERE id = new.id; +END; diff --git a/pkg/optimizer/db/sqlite_fts.go b/pkg/optimizer/db/sqlite_fts.go new file mode 100644 index 0000000000..a4a3c9e421 --- /dev/null +++ b/pkg/optimizer/db/sqlite_fts.go @@ -0,0 +1,8 @@ +// Package db provides database operations for the optimizer. +// This file handles FTS5 (Full-Text Search) using modernc.org/sqlite (pure Go). +package db + +import ( + // Pure Go SQLite driver with FTS5 support + _ "modernc.org/sqlite" +) diff --git a/pkg/optimizer/doc.go b/pkg/optimizer/doc.go new file mode 100644 index 0000000000..0808bb76b2 --- /dev/null +++ b/pkg/optimizer/doc.go @@ -0,0 +1,83 @@ +// Package optimizer provides semantic tool discovery and ingestion for MCP servers. +// +// The optimizer package implements an ingestion service that discovers MCP backends +// from ToolHive, generates semantic embeddings for tools using ONNX Runtime, and stores +// them in a SQLite database with vector search capabilities. +// +// # Architecture +// +// The optimizer follows a similar architecture to mcp-optimizer (Python) but adapted +// for Go idioms and patterns: +// +// pkg/optimizer/ +// ├── doc.go // Package documentation +// ├── models/ // Database models and types +// │ ├── models.go // Core domain models (Server, Tool, etc.) +// │ └── transport.go // Transport and status enums +// ├── db/ // Database layer +// │ ├── db.go // Database connection and config +// │ ├── fts.go // FTS5 database for BM25 search +// │ ├── schema_fts.sql // Embedded FTS5 schema (executed directly) +// │ ├── hybrid.go // Hybrid search (semantic + BM25) +// │ ├── backend_server.go // Backend server operations +// │ └── backend_tool.go // Backend tool operations +// ├── embeddings/ // Embedding generation +// │ ├── manager.go // Embedding manager with ONNX Runtime +// │ └── cache.go // Optional embedding cache +// ├── mcpclient/ // MCP client for tool discovery +// │ └── client.go // MCP client wrapper +// ├── ingestion/ // Core ingestion service +// │ ├── service.go // Ingestion service implementation +// │ └── errors.go // Custom errors +// └── tokens/ // Token counting (for LLM consumption) +// └── counter.go // Token counter using tiktoken-go +// +// # Core Concepts +// +// **Ingestion**: Discovers MCP backends from ToolHive (via Docker or Kubernetes), +// connects to each backend to list tools, generates embeddings, and stores in database. +// +// **Embeddings**: Uses ONNX Runtime to generate semantic embeddings for tools and servers. +// Embeddings enable semantic search to find relevant tools based on natural language queries. +// +// **Database**: Hybrid approach using chromem-go for vector search and SQLite FTS5 for +// keyword search. The database is ephemeral (in-memory by default, optional persistence) +// and schema is initialized directly on startup without migrations. +// +// **Terminology**: Uses "BackendServer" and "BackendTool" to explicitly refer to MCP server +// metadata, distinguishing from vMCP's broader "Backend" concept which represents workloads. +// +// **Token Counting**: Tracks token counts for tools to measure LLM consumption and +// calculate token savings from semantic filtering. +// +// # Usage +// +// The optimizer is integrated into vMCP as native tools: +// +// 1. **vMCP Integration**: The optimizer runs as part of vMCP, exposing +// optim.find_tool and optim.call_tool to clients. +// +// 2. **Event-Driven Ingestion**: Tools are ingested when vMCP sessions +// are registered, not via polling. +// +// Example vMCP integration (see pkg/vmcp/optimizer): +// +// import ( +// "github.com/stacklok/toolhive/pkg/optimizer/ingestion" +// "github.com/stacklok/toolhive/pkg/optimizer/embeddings" +// ) +// +// // Create embedding manager +// embMgr, err := embeddings.NewManager(embeddings.Config{ +// BackendType: "placeholder", // or "ollama" or "openai-compatible" +// Dimension: 384, +// }) +// +// // Create ingestion service +// svc, err := ingestion.NewService(ctx, ingestion.Config{ +// DBConfig: dbConfig, +// }, embMgr) +// +// // Ingest a server (called by vMCP's OnRegisterSession hook) +// err = svc.IngestServer(ctx, "weather-service", tools, target) +package optimizer diff --git a/pkg/optimizer/embeddings/cache.go b/pkg/optimizer/embeddings/cache.go new file mode 100644 index 0000000000..7638939f5e --- /dev/null +++ b/pkg/optimizer/embeddings/cache.go @@ -0,0 +1,101 @@ +// Package embeddings provides caching for embedding vectors. +package embeddings + +import ( + "container/list" + "sync" +) + +// cache implements an LRU cache for embeddings +type cache struct { + maxSize int + mu sync.RWMutex + items map[string]*list.Element + lru *list.List + hits int64 + misses int64 +} + +type cacheEntry struct { + key string + value []float32 +} + +// newCache creates a new LRU cache +func newCache(maxSize int) *cache { + return &cache{ + maxSize: maxSize, + items: make(map[string]*list.Element), + lru: list.New(), + } +} + +// Get retrieves an embedding from the cache +func (c *cache) Get(key string) []float32 { + c.mu.Lock() + defer c.mu.Unlock() + + elem, ok := c.items[key] + if !ok { + c.misses++ + return nil + } + + c.hits++ + c.lru.MoveToFront(elem) + return elem.Value.(*cacheEntry).value +} + +// Put stores an embedding in the cache +func (c *cache) Put(key string, value []float32) { + c.mu.Lock() + defer c.mu.Unlock() + + // Check if key already exists + if elem, ok := c.items[key]; ok { + c.lru.MoveToFront(elem) + elem.Value.(*cacheEntry).value = value + return + } + + // Add new entry + entry := &cacheEntry{ + key: key, + value: value, + } + elem := c.lru.PushFront(entry) + c.items[key] = elem + + // Evict if necessary + if c.lru.Len() > c.maxSize { + c.evict() + } +} + +// evict removes the least recently used item +func (c *cache) evict() { + elem := c.lru.Back() + if elem != nil { + c.lru.Remove(elem) + entry := elem.Value.(*cacheEntry) + delete(c.items, entry.key) + } +} + +// Size returns the current cache size +func (c *cache) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.lru.Len() +} + +// Clear clears the cache +func (c *cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.items = make(map[string]*list.Element) + c.lru = list.New() + c.hits = 0 + c.misses = 0 +} diff --git a/pkg/optimizer/embeddings/cache_test.go b/pkg/optimizer/embeddings/cache_test.go new file mode 100644 index 0000000000..9992d64605 --- /dev/null +++ b/pkg/optimizer/embeddings/cache_test.go @@ -0,0 +1,169 @@ +package embeddings + +import ( + "testing" +) + +func TestCache_GetPut(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Test cache miss + result := c.Get("key1") + if result != nil { + t.Error("Expected cache miss for non-existent key") + } + if c.misses != 1 { + t.Errorf("Expected 1 miss, got %d", c.misses) + } + + // Test cache put and hit + embedding := []float32{1.0, 2.0, 3.0} + c.Put("key1", embedding) + + result = c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + if c.hits != 1 { + t.Errorf("Expected 1 hit, got %d", c.hits) + } + + // Verify embedding values + if len(result) != len(embedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(embedding)) + } + for i := range embedding { + if result[i] != embedding[i] { + t.Errorf("Embedding value mismatch at index %d: got %f, want %f", i, result[i], embedding[i]) + } + } +} + +func TestCache_LRUEviction(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items (fills cache) + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2, got %d", c.Size()) + } + + // Add third item (should evict key1) + c.Put("key3", []float32{3.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2 after eviction, got %d", c.Size()) + } + + // key1 should be evicted (oldest) + if result := c.Get("key1"); result != nil { + t.Error("key1 should have been evicted") + } + + // key2 and key3 should still exist + if result := c.Get("key2"); result == nil { + t.Error("key2 should still exist") + } + if result := c.Get("key3"); result == nil { + t.Error("key3 should still exist") + } +} + +func TestCache_MoveToFrontOnAccess(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + // Access key1 (moves it to front) + c.Get("key1") + + // Add third item (should evict key2, not key1) + c.Put("key3", []float32{3.0}) + + // key1 should still exist (was accessed recently) + if result := c.Get("key1"); result == nil { + t.Error("key1 should still exist (was accessed recently)") + } + + // key2 should be evicted (was oldest) + if result := c.Get("key2"); result != nil { + t.Error("key2 should have been evicted") + } + + // key3 should exist + if result := c.Get("key3"); result == nil { + t.Error("key3 should exist") + } +} + +func TestCache_UpdateExistingKey(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add initial value + c.Put("key1", []float32{1.0}) + + // Update with new value + newEmbedding := []float32{2.0, 3.0} + c.Put("key1", newEmbedding) + + // Should get updated value + result := c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + + if len(result) != len(newEmbedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(newEmbedding)) + } + + // Cache size should still be 1 + if c.Size() != 1 { + t.Errorf("Expected cache size 1, got %d", c.Size()) + } +} + +func TestCache_Clear(t *testing.T) { + t.Parallel() + c := newCache(10) + + // Add some items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + c.Put("key3", []float32{3.0}) + + // Access some items to generate stats + c.Get("key1") + c.Get("missing") + + if c.Size() != 3 { + t.Errorf("Expected cache size 3, got %d", c.Size()) + } + + // Clear cache + c.Clear() + + if c.Size() != 0 { + t.Errorf("Expected cache size 0 after clear, got %d", c.Size()) + } + + // Stats should be reset + if c.hits != 0 { + t.Errorf("Expected 0 hits after clear, got %d", c.hits) + } + if c.misses != 0 { + t.Errorf("Expected 0 misses after clear, got %d", c.misses) + } + + // Items should be gone + if result := c.Get("key1"); result != nil { + t.Error("key1 should be gone after clear") + } +} diff --git a/pkg/optimizer/embeddings/manager.go b/pkg/optimizer/embeddings/manager.go new file mode 100644 index 0000000000..9ccc94fca3 --- /dev/null +++ b/pkg/optimizer/embeddings/manager.go @@ -0,0 +1,281 @@ +package embeddings + +import ( + "fmt" + "sync" + + "github.com/stacklok/toolhive/pkg/logger" +) + +const ( + // BackendTypePlaceholder is the placeholder backend type + BackendTypePlaceholder = "placeholder" +) + +// Config holds configuration for the embedding manager +type Config struct { + // BackendType specifies which backend to use: + // - "ollama": Ollama native API + // - "vllm": vLLM OpenAI-compatible API + // - "unified": Generic OpenAI-compatible API (works with both) + // - "placeholder": Hash-based embeddings for testing + BackendType string + + // BaseURL is the base URL for the embedding service + // - Ollama: http://localhost:11434 + // - vLLM: http://localhost:8000 + BaseURL string + + // Model is the model name to use + // - Ollama: "nomic-embed-text", "all-minilm" + // - vLLM: "sentence-transformers/all-MiniLM-L6-v2", "intfloat/e5-mistral-7b-instruct" + Model string + + // Dimension is the embedding dimension (default 384 for all-MiniLM-L6-v2) + Dimension int + + // EnableCache enables caching of embeddings + EnableCache bool + + // MaxCacheSize is the maximum number of embeddings to cache (default 1000) + MaxCacheSize int +} + +// Backend interface for different embedding implementations +type Backend interface { + Embed(text string) ([]float32, error) + EmbedBatch(texts []string) ([][]float32, error) + Dimension() int + Close() error +} + +// Manager manages embedding generation using pluggable backends +// Default backend is all-MiniLM-L6-v2 (same model as codegate) +type Manager struct { + config *Config + backend Backend + cache *cache + mu sync.RWMutex +} + +// NewManager creates a new embedding manager +func NewManager(config *Config) (*Manager, error) { + if config.Dimension == 0 { + config.Dimension = 384 // Default dimension for all-MiniLM-L6-v2 + } + + if config.MaxCacheSize == 0 { + config.MaxCacheSize = 1000 + } + + // Default to placeholder (zero dependencies) + if config.BackendType == "" { + config.BackendType = "placeholder" + } + + // Initialize backend based on configuration + var backend Backend + var err error + + switch config.BackendType { + case "ollama": + // Use Ollama native API (requires ollama serve) + baseURL := config.BaseURL + if baseURL == "" { + baseURL = "http://localhost:11434" + } + model := config.Model + if model == "" { + model = "nomic-embed-text" + } + backend, err = NewOllamaBackend(baseURL, model) + if err != nil { + logger.Warnf("Failed to initialize Ollama backend: %v", err) + logger.Info("Falling back to placeholder embeddings. To use Ollama: ollama serve && ollama pull nomic-embed-text") + backend = &PlaceholderBackend{dimension: config.Dimension} + } + + case "vllm", "unified", "openai": + // Use OpenAI-compatible API + // vLLM is recommended for production Kubernetes deployments (GPU-accelerated, high-throughput) + // Also supports: Ollama v1 API, OpenAI, or any OpenAI-compatible service + if config.BaseURL == "" { + return nil, fmt.Errorf("BaseURL is required for %s backend", config.BackendType) + } + if config.Model == "" { + return nil, fmt.Errorf("model is required for %s backend", config.BackendType) + } + backend, err = NewOpenAICompatibleBackend(config.BaseURL, config.Model, config.Dimension) + if err != nil { + logger.Warnf("Failed to initialize %s backend: %v", config.BackendType, err) + logger.Infof("Falling back to placeholder embeddings") + backend = &PlaceholderBackend{dimension: config.Dimension} + } + + case BackendTypePlaceholder: + // Use placeholder for testing + backend = &PlaceholderBackend{dimension: config.Dimension} + + default: + return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, placeholder)", config.BackendType) + } + + m := &Manager{ + config: config, + backend: backend, + } + + if config.EnableCache { + m.cache = newCache(config.MaxCacheSize) + } + + return m, nil +} + +// GenerateEmbedding generates embeddings for the given texts +// Returns a 2D slice where each row is an embedding for the corresponding text +// Uses all-MiniLM-L6-v2 by default (same model as codegate) +func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no texts provided") + } + + // Check cache for single text requests + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + if cached := m.cache.Get(texts[0]); cached != nil { + logger.Debugf("Cache hit for embedding") + return [][]float32{cached}, nil + } + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Use backend to generate embeddings + embeddings, err := m.backend.EmbedBatch(texts) + if err != nil { + // If backend fails, fall back to placeholder for non-placeholder backends + if m.config.BackendType != "placeholder" { + logger.Warnf("%s backend failed: %v, falling back to placeholder", m.config.BackendType, err) + placeholder := &PlaceholderBackend{dimension: m.config.Dimension} + embeddings, err = placeholder.EmbedBatch(texts) + if err != nil { + return nil, fmt.Errorf("failed to generate embeddings: %w", err) + } + } else { + return nil, fmt.Errorf("failed to generate embeddings: %w", err) + } + } + + // Cache single embeddings + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + m.cache.Put(texts[0], embeddings[0]) + } + + logger.Debugf("Generated %d embeddings (dimension: %d)", len(embeddings), m.backend.Dimension()) + return embeddings, nil +} + +// PlaceholderBackend is a simple backend for testing +type PlaceholderBackend struct { + dimension int +} + +// Embed generates a deterministic hash-based embedding for the given text. +func (p *PlaceholderBackend) Embed(text string) ([]float32, error) { + return p.generatePlaceholderEmbedding(text), nil +} + +// EmbedBatch generates embeddings for multiple texts. +func (p *PlaceholderBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + for i, text := range texts { + embeddings[i] = p.generatePlaceholderEmbedding(text) + } + return embeddings, nil +} + +// Dimension returns the embedding dimension. +func (p *PlaceholderBackend) Dimension() int { + return p.dimension +} + +// Close closes the backend (no-op for placeholder). +func (*PlaceholderBackend) Close() error { + return nil +} + +func (p *PlaceholderBackend) generatePlaceholderEmbedding(text string) []float32 { + embedding := make([]float32, p.dimension) + + // Simple hash-based generation for testing + hash := 0 + for _, c := range text { + hash = (hash*31 + int(c)) % 1000000 + } + + // Generate deterministic values + for i := range embedding { + hash = (hash*1103515245 + 12345) % 1000000 + embedding[i] = float32(hash) / 1000000.0 + } + + // Normalize the embedding (L2 normalization) + var norm float32 + for _, v := range embedding { + norm += v * v + } + if norm > 0 { + norm = float32(1.0 / float64(norm)) + for i := range embedding { + embedding[i] *= norm + } + } + + return embedding +} + +// GetCacheStats returns cache statistics +func (m *Manager) GetCacheStats() map[string]interface{} { + if !m.config.EnableCache || m.cache == nil { + return map[string]interface{}{ + "enabled": false, + } + } + + return map[string]interface{}{ + "enabled": true, + "hits": m.cache.hits, + "misses": m.cache.misses, + "size": m.cache.Size(), + "maxsize": m.config.MaxCacheSize, + } +} + +// ClearCache clears the embedding cache +func (m *Manager) ClearCache() { + if m.config.EnableCache && m.cache != nil { + m.cache.Clear() + logger.Info("Embedding cache cleared") + } +} + +// Close releases resources +func (m *Manager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + if m.backend != nil { + return m.backend.Close() + } + + return nil +} + +// Dimension returns the embedding dimension +func (m *Manager) Dimension() int { + if m.backend != nil { + return m.backend.Dimension() + } + return m.config.Dimension +} diff --git a/pkg/optimizer/embeddings/ollama.go b/pkg/optimizer/embeddings/ollama.go new file mode 100644 index 0000000000..d6f4874375 --- /dev/null +++ b/pkg/optimizer/embeddings/ollama.go @@ -0,0 +1,128 @@ +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// OllamaBackend implements the Backend interface using Ollama +// This provides local embeddings without remote API calls +// Ollama must be running locally (ollama serve) +type OllamaBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type ollamaEmbedRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type ollamaEmbedResponse struct { + Embedding []float64 `json:"embedding"` +} + +// NewOllamaBackend creates a new Ollama backend +// Requires Ollama to be running locally: ollama serve +// Default model: nomic-embed-text (768 dimensions) +func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) { + if baseURL == "" { + baseURL = "http://localhost:11434" + } + if model == "" { + model = "nomic-embed-text" // Default embedding model + } + + logger.Infof("Initializing Ollama backend (model: %s, url: %s)", model, baseURL) + + backend := &OllamaBackend{ + baseURL: baseURL, + model: model, + dimension: 768, // nomic-embed-text dimension + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to Ollama at %s: %w (is 'ollama serve' running?)", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to Ollama") + return backend, nil +} + +// Embed generates an embedding for a single text +func (o *OllamaBackend) Embed(text string) ([]float32, error) { + reqBody := ollamaEmbedRequest{ + Model: o.model, + Prompt: text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := o.client.Post( + o.baseURL+"/api/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call Ollama API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp ollamaEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Convert []float64 to []float32 + embedding := make([]float32, len(embedResp.Embedding)) + for i, v := range embedResp.Embedding { + embedding[i] = float32(v) + } + + return embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OllamaBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OllamaBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OllamaBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/pkg/optimizer/embeddings/ollama_test.go b/pkg/optimizer/embeddings/ollama_test.go new file mode 100644 index 0000000000..5254b7c072 --- /dev/null +++ b/pkg/optimizer/embeddings/ollama_test.go @@ -0,0 +1,106 @@ +package embeddings + +import ( + "testing" +) + +func TestOllamaBackend_Placeholder(t *testing.T) { + t.Parallel() + // This test verifies that Ollama backend is properly structured + // Actual Ollama tests require ollama to be running + + // Test that NewOllamaBackend handles connection failure gracefully + _, err := NewOllamaBackend("http://localhost:99999", "nomic-embed-text") + if err == nil { + t.Error("Expected error when connecting to invalid Ollama URL") + } +} + +func TestManagerWithOllama(t *testing.T) { + t.Parallel() + // Test that Manager falls back to placeholder when Ollama is not available or model not pulled + config := &Config{ + BackendType: "ollama", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Should work with placeholder backend fallback + // (Ollama might not have model pulled, so it falls back to placeholder) + embeddings, err := manager.GenerateEmbedding([]string{"test text"}) + + // If Ollama is available with the model, great! + // If not, it should have fallen back to placeholder + if err != nil { + // Check if it's a "model not found" error - this is expected + if embeddings == nil { + t.Skip("Ollama not available or model not pulled (expected in CI/test environments)") + } + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + + // Dimension could be 384 (placeholder) or 768 (Ollama nomic-embed-text) + if len(embeddings[0]) != 384 && len(embeddings[0]) != 768 { + t.Errorf("Expected dimension 384 or 768, got %d", len(embeddings[0])) + } +} + +func TestManagerWithPlaceholder(t *testing.T) { + t.Parallel() + // Test explicit placeholder backend + config := &Config{ + BackendType: "placeholder", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test single embedding + embeddings, err := manager.GenerateEmbedding([]string{"hello world"}) + if err != nil { + t.Fatalf("Failed to generate embedding: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } + + // Test batch embeddings + texts := []string{"text 1", "text 2", "text 3"} + embeddings, err = manager.GenerateEmbedding(texts) + if err != nil { + t.Fatalf("Failed to generate batch embeddings: %v", err) + } + + if len(embeddings) != 3 { + t.Errorf("Expected 3 embeddings, got %d", len(embeddings)) + } + + // Verify embeddings are deterministic + embeddings2, _ := manager.GenerateEmbedding([]string{"text 1"}) + for i := range embeddings[0] { + if embeddings[0][i] != embeddings2[0][i] { + t.Error("Embeddings should be deterministic") + break + } + } +} diff --git a/pkg/optimizer/embeddings/openai_compatible.go b/pkg/optimizer/embeddings/openai_compatible.go new file mode 100644 index 0000000000..8a86129d56 --- /dev/null +++ b/pkg/optimizer/embeddings/openai_compatible.go @@ -0,0 +1,149 @@ +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// OpenAICompatibleBackend implements the Backend interface for OpenAI-compatible APIs. +// +// Supported Services: +// - vLLM: Recommended for production Kubernetes deployments +// - High-throughput GPU-accelerated inference +// - PagedAttention for efficient GPU memory utilization +// - Superior scalability for multi-user environments +// - Ollama: Good for local development (via /v1/embeddings endpoint) +// - OpenAI: For cloud-based embeddings +// - Any OpenAI-compatible embedding service +// +// For production deployments, vLLM is strongly recommended due to its performance +// characteristics and Kubernetes-native design. +type OpenAICompatibleBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type openaiEmbedRequest struct { + Model string `json:"model"` + Input string `json:"input"` // OpenAI standard uses "input" +} + +type openaiEmbedResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` +} + +// NewOpenAICompatibleBackend creates a new OpenAI-compatible backend. +// +// Examples: +// - vLLM: NewOpenAICompatibleBackend("http://vllm-service:8000", "sentence-transformers/all-MiniLM-L6-v2", 384) +// - Ollama: NewOpenAICompatibleBackend("http://localhost:11434", "nomic-embed-text", 768) +// - OpenAI: NewOpenAICompatibleBackend("https://api.openai.com", "text-embedding-3-small", 1536) +func NewOpenAICompatibleBackend(baseURL, model string, dimension int) (*OpenAICompatibleBackend, error) { + if baseURL == "" { + return nil, fmt.Errorf("baseURL is required for OpenAI-compatible backend") + } + if model == "" { + return nil, fmt.Errorf("model is required for OpenAI-compatible backend") + } + if dimension == 0 { + dimension = 384 // Default dimension + } + + logger.Infof("Initializing OpenAI-compatible backend (model: %s, url: %s)", model, baseURL) + + backend := &OpenAICompatibleBackend{ + baseURL: baseURL, + model: model, + dimension: dimension, + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to %s: %w", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to OpenAI-compatible service") + return backend, nil +} + +// Embed generates an embedding for a single text using OpenAI-compatible API +func (o *OpenAICompatibleBackend) Embed(text string) ([]float32, error) { + reqBody := openaiEmbedRequest{ + Model: o.model, + Input: text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Use standard OpenAI v1 endpoint + resp, err := o.client.Post( + o.baseURL+"/v1/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call embeddings API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp openaiEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(embedResp.Data) == 0 { + return nil, fmt.Errorf("no embeddings in response") + } + + return embedResp.Data[0].Embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OpenAICompatibleBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OpenAICompatibleBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OpenAICompatibleBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/pkg/optimizer/embeddings/openai_compatible_test.go b/pkg/optimizer/embeddings/openai_compatible_test.go new file mode 100644 index 0000000000..916ad0cb8f --- /dev/null +++ b/pkg/optimizer/embeddings/openai_compatible_test.go @@ -0,0 +1,235 @@ +package embeddings + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +const testEmbeddingsEndpoint = "/v1/embeddings" + +func TestOpenAICompatibleBackend(t *testing.T) { + t.Parallel() + // Create a test server that mimics OpenAI-compatible API + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + var req openaiEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("Failed to decode request: %v", err) + } + + // Return a mock embedding response + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + Model: req.Model, + } + + // Fill with test data + for i := range resp.Data[0].Embedding { + resp.Data[0].Embedding[i] = float32(i) / 384.0 + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + // Health check endpoint + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test backend creation + backend, err := NewOpenAICompatibleBackend(server.URL, "test-model", 384) + if err != nil { + t.Fatalf("Failed to create backend: %v", err) + } + defer backend.Close() + + // Test embedding generation + embedding, err := backend.Embed("test text") + if err != nil { + t.Fatalf("Failed to generate embedding: %v", err) + } + + if len(embedding) != 384 { + t.Errorf("Expected embedding dimension 384, got %d", len(embedding)) + } + + // Test batch embedding + texts := []string{"text1", "text2", "text3"} + embeddings, err := backend.EmbedBatch(texts) + if err != nil { + t.Fatalf("Failed to generate batch embeddings: %v", err) + } + + if len(embeddings) != len(texts) { + t.Errorf("Expected %d embeddings, got %d", len(texts), len(embeddings)) + } +} + +func TestOpenAICompatibleBackendErrors(t *testing.T) { + t.Parallel() + // Test missing baseURL + _, err := NewOpenAICompatibleBackend("", "model", 384) + if err == nil { + t.Error("Expected error for missing baseURL") + } + + // Test missing model + _, err = NewOpenAICompatibleBackend("http://localhost:8000", "", 384) + if err == nil { + t.Error("Expected error for missing model") + } +} + +func TestManagerWithVLLM(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with vLLM backend + config := &Config{ + BackendType: "vllm", + BaseURL: server.URL, + Model: "sentence-transformers/all-MiniLM-L6-v2", + Dimension: 384, + EnableCache: true, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } +} + +func TestManagerWithUnified(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 768), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with unified backend + config := &Config{ + BackendType: "unified", + BaseURL: server.URL, + Model: "nomic-embed-text", + Dimension: 768, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } +} + +func TestManagerFallbackBehavior(t *testing.T) { + t.Parallel() + // Test that invalid vLLM backend falls back to placeholder + config := &Config{ + BackendType: "vllm", + BaseURL: "http://invalid-host-that-does-not-exist:99999", + Model: "test-model", + Dimension: 384, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Should still work with placeholder fallback + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings with fallback: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } +} diff --git a/pkg/optimizer/ingestion/errors.go b/pkg/optimizer/ingestion/errors.go new file mode 100644 index 0000000000..cb33a97dcb --- /dev/null +++ b/pkg/optimizer/ingestion/errors.go @@ -0,0 +1,21 @@ +// Package ingestion provides services for ingesting MCP tools into the database. +package ingestion + +import "errors" + +var ( + // ErrIngestionFailed is returned when ingestion fails + ErrIngestionFailed = errors.New("ingestion failed") + + // ErrBackendRetrievalFailed is returned when backend retrieval fails + ErrBackendRetrievalFailed = errors.New("backend retrieval failed") + + // ErrToolHiveUnavailable is returned when ToolHive is unavailable + ErrToolHiveUnavailable = errors.New("ToolHive unavailable") + + // ErrBackendStatusNil is returned when backend status is nil + ErrBackendStatusNil = errors.New("backend status cannot be nil") + + // ErrInvalidRuntimeMode is returned for invalid runtime mode + ErrInvalidRuntimeMode = errors.New("invalid runtime mode: must be 'docker' or 'k8s'") +) diff --git a/pkg/optimizer/ingestion/service.go b/pkg/optimizer/ingestion/service.go new file mode 100644 index 0000000000..821f970d6f --- /dev/null +++ b/pkg/optimizer/ingestion/service.go @@ -0,0 +1,215 @@ +package ingestion + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/db" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/optimizer/models" + "github.com/stacklok/toolhive/pkg/optimizer/tokens" +) + +// Config holds configuration for the ingestion service +type Config struct { + // Database configuration + DBConfig *db.Config + + // Embedding configuration + EmbeddingConfig *embeddings.Config + + // MCP timeout in seconds + MCPTimeout int + + // Workloads to skip during ingestion + SkippedWorkloads []string + + // Runtime mode: "docker" or "k8s" + RuntimeMode string + + // Kubernetes configuration (used when RuntimeMode is "k8s") + K8sAPIServerURL string + K8sNamespace string + K8sAllNamespaces bool +} + +// Service handles ingestion of MCP backends and their tools +type Service struct { + config *Config + database *db.DB + embeddingManager *embeddings.Manager + tokenCounter *tokens.Counter + backendServerOps *db.BackendServerOps + backendToolOps *db.BackendToolOps +} + +// NewService creates a new ingestion service +func NewService(config *Config) (*Service, error) { + // Set defaults + if config.MCPTimeout == 0 { + config.MCPTimeout = 30 + } + if len(config.SkippedWorkloads) == 0 { + config.SkippedWorkloads = []string{"inspector", "mcp-optimizer"} + } + + // Initialize database + database, err := db.NewDB(config.DBConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + // Initialize embedding manager + embeddingManager, err := embeddings.NewManager(config.EmbeddingConfig) + if err != nil { + _ = database.Close() + return nil, fmt.Errorf("failed to initialize embedding manager: %w", err) + } + + // Initialize token counter + tokenCounter := tokens.NewCounter() + + // Create chromem-go embeddingFunc from our embedding manager + embeddingFunc := func(_ context.Context, text string) ([]float32, error) { + // Our manager takes a slice, so wrap the single text + embeddingsResult, err := embeddingManager.GenerateEmbedding([]string{text}) + if err != nil { + return nil, err + } + if len(embeddingsResult) == 0 { + return nil, fmt.Errorf("no embeddings generated") + } + return embeddingsResult[0], nil + } + + svc := &Service{ + config: config, + database: database, + embeddingManager: embeddingManager, + tokenCounter: tokenCounter, + backendServerOps: db.NewBackendServerOps(database, embeddingFunc), + backendToolOps: db.NewBackendToolOps(database, embeddingFunc), + } + + logger.Info("Ingestion service initialized for event-driven ingestion (chromem-go)") + return svc, nil +} + +// IngestServer ingests a single MCP server and its tools into the optimizer database. +// This is called by vMCP during session registration for each backend server. +// +// Parameters: +// - serverID: Unique identifier for the backend server +// - serverName: Human-readable server name +// - description: Optional server description +// - tools: List of tools available from this server +// +// This method will: +// 1. Create or update the backend server record (simplified metadata only) +// 2. Generate embeddings for server and tools +// 3. Count tokens for each tool +// 4. Store everything in the database for semantic search +// +// Note: URL, transport, status are NOT stored - vMCP manages backend lifecycle +func (s *Service) IngestServer( + ctx context.Context, + serverID string, + serverName string, + description *string, + tools []mcp.Tool, +) error { + logger.Infof("Ingesting server: %s (%d tools)", serverName, len(tools)) + + // Create backend server record (simplified - vMCP manages lifecycle) + // chromem-go will generate embeddings automatically from the content + backendServer := &models.BackendServer{ + ID: serverID, + Name: serverName, + Description: description, + Group: "default", // TODO: Pass group from vMCP if needed + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create or update server (chromem-go handles embeddings) + if err := s.backendServerOps.Update(ctx, backendServer); err != nil { + return fmt.Errorf("failed to create/update server %s: %w", serverName, err) + } + logger.Debugf("Created/updated server: %s", serverName) + + // Sync tools for this server + toolCount, err := s.syncBackendTools(ctx, serverID, serverName, tools) + if err != nil { + return fmt.Errorf("failed to sync tools for %s: %w", serverName, err) + } + + logger.Infof("Successfully ingested server %s with %d tools", serverName, toolCount) + return nil +} + +// syncBackendTools synchronizes tools for a backend server +func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) { + // Delete existing tools + if err := s.backendToolOps.DeleteByServer(ctx, serverID); err != nil { + return 0, fmt.Errorf("failed to delete existing tools: %w", err) + } + + if len(tools) == 0 { + return 0, nil + } + + // Create tool records (chromem-go will generate embeddings automatically) + for _, tool := range tools { + // Extract description for embedding + description := tool.Description + + // Convert InputSchema to JSON + schemaJSON, err := json.Marshal(tool.InputSchema) + if err != nil { + return 0, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + + backendTool := &models.BackendTool{ + ID: uuid.New().String(), + MCPServerID: serverID, + ToolName: tool.Name, + Description: &description, + InputSchema: schemaJSON, + TokenCount: s.tokenCounter.CountToolTokens(tool), + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + if err := s.backendToolOps.Create(ctx, backendTool, serverName); err != nil { + return 0, fmt.Errorf("failed to create tool %s: %w", tool.Name, err) + } + } + + logger.Infof("Synced %d tools for server %s", len(tools), serverName) + return len(tools), nil +} + +// Close releases resources +func (s *Service) Close() error { + var errs []error + + if err := s.embeddingManager.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close embedding manager: %w", err)) + } + + if err := s.database.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close database: %w", err)) + } + + if len(errs) > 0 { + return fmt.Errorf("errors closing service: %v", errs) + } + + return nil +} diff --git a/pkg/optimizer/ingestion/service_test.go b/pkg/optimizer/ingestion/service_test.go new file mode 100644 index 0000000000..51c73767b8 --- /dev/null +++ b/pkg/optimizer/ingestion/service_test.go @@ -0,0 +1,148 @@ +package ingestion + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/db" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" +) + +// TestServiceCreationAndIngestion demonstrates the complete chromem-go workflow: +// 1. Create in-memory database +// 2. Initialize ingestion service +// 3. Ingest server and tools +// 4. Query the database +func TestServiceCreationAndIngestion(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create temporary directory for persistence (optional) + tmpDir := t.TempDir() + + // Initialize service with placeholder embeddings (no dependencies) + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", // Use placeholder for testing + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get the current weather for a location", + }, + { + Name: "search_web", + Description: "Search the web for information", + }, + } + + // Ingest server with tools + serverName := "test-server" + serverID := "test-server-id" + description := "A test MCP server" + + err = svc.IngestServer(ctx, serverID, serverName, &description, tools) + require.NoError(t, err) + + // Query tools + allTools, err := svc.backendToolOps.ListByServer(ctx, serverID) + require.NoError(t, err) + require.Len(t, allTools, 2, "Expected 2 tools to be ingested") + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range allTools { + toolNames[tool.ToolName] = true + } + require.True(t, toolNames["get_weather"], "get_weather tool should be present") + require.True(t, toolNames["search_web"], "search_web tool should be present") + + // Search for similar tools + results, err := svc.backendToolOps.Search(ctx, "weather information", 5, &serverID) + require.NoError(t, err) + require.NotEmpty(t, results, "Should find at least one similar tool") + + // With placeholder embeddings (hash-based), semantic similarity isn't guaranteed + // Just verify we got results back + require.Len(t, results, 2, "Should return both tools") + + // Verify both tools are present (order doesn't matter with placeholder embeddings) + toolNamesFound := make(map[string]bool) + for _, result := range results { + toolNamesFound[result.ToolName] = true + } + require.True(t, toolNamesFound["get_weather"], "get_weather should be in results") + require.True(t, toolNamesFound["search_web"], "search_web should be in results") +} + +// TestServiceWithOllama demonstrates using real embeddings (requires Ollama running) +// This test can be enabled locally to verify Ollama integration +func TestServiceWithOllama(t *testing.T) { + t.Parallel() + + // Skip if not explicitly enabled or Ollama is not available + if os.Getenv("TEST_OLLAMA") != "true" { + t.Skip("Skipping Ollama integration test (set TEST_OLLAMA=true to enable)") + } + + ctx := context.Background() + tmpDir := t.TempDir() + + // Initialize service with Ollama embeddings + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "ollama-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get current weather conditions for any location worldwide", + }, + { + Name: "send_email", + Description: "Send an email message to a recipient", + }, + } + + // Ingest server + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Search for weather-related tools + results, err := svc.backendToolOps.Search(ctx, "What's the temperature outside?", 5, nil) + require.NoError(t, err) + require.NotEmpty(t, results) + + // With real embeddings, weather tool should be most similar + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") +} diff --git a/pkg/optimizer/models/errors.go b/pkg/optimizer/models/errors.go new file mode 100644 index 0000000000..984dd43eea --- /dev/null +++ b/pkg/optimizer/models/errors.go @@ -0,0 +1,16 @@ +// Package models defines domain models for the optimizer. +// It includes structures for MCP servers, tools, and related metadata. +package models + +import "errors" + +var ( + // ErrRemoteServerMissingURL is returned when a remote server doesn't have a URL + ErrRemoteServerMissingURL = errors.New("remote servers must have URL") + + // ErrContainerServerMissingPackage is returned when a container server doesn't have a package + ErrContainerServerMissingPackage = errors.New("container servers must have package") + + // ErrInvalidTokenMetrics is returned when token metrics are inconsistent + ErrInvalidTokenMetrics = errors.New("invalid token metrics: calculated values don't match") +) diff --git a/pkg/optimizer/models/models.go b/pkg/optimizer/models/models.go new file mode 100644 index 0000000000..8e1e065a38 --- /dev/null +++ b/pkg/optimizer/models/models.go @@ -0,0 +1,173 @@ +package models + +import ( + "encoding/json" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// BaseMCPServer represents the common fields for MCP servers. +type BaseMCPServer struct { + ID string `json:"id"` + Name string `json:"name"` + Remote bool `json:"remote"` + Transport TransportType `json:"transport"` + Description *string `json:"description,omitempty"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + Group string `json:"group"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryServer represents an MCP server from the registry catalog. +type RegistryServer struct { + BaseMCPServer + URL *string `json:"url,omitempty"` // For remote servers + Package *string `json:"package,omitempty"` // For container servers +} + +// Validate checks if the registry server has valid data. +// Remote servers must have URL, container servers must have package. +func (r *RegistryServer) Validate() error { + if r.Remote && r.URL == nil { + return ErrRemoteServerMissingURL + } + if !r.Remote && r.Package == nil { + return ErrContainerServerMissingPackage + } + return nil +} + +// BackendServer represents a running MCP server backend. +// Simplified: Only stores metadata needed for tool organization and search results. +// vMCP manages backend lifecycle (URL, status, transport, etc.) +type BackendServer struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Group string `json:"group"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// BaseTool represents the common fields for tools. +type BaseTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + Details mcp.Tool `json:"details"` + DetailsEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryTool represents a tool from a registry MCP server. +type RegistryTool struct { + BaseTool +} + +// BackendTool represents a tool from a backend MCP server. +// With chromem-go, embeddings are managed by the database. +type BackendTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + ToolName string `json:"tool_name"` + Description *string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema,omitempty"` + ToolEmbedding []float32 `json:"-"` // Managed by chromem-go + TokenCount int `json:"token_count"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// ToolDetailsToJSON converts mcp.Tool to JSON for storage in the database. +func ToolDetailsToJSON(tool mcp.Tool) (string, error) { + data, err := json.Marshal(tool) + if err != nil { + return "", err + } + return string(data), nil +} + +// ToolDetailsFromJSON converts JSON to mcp.Tool +func ToolDetailsFromJSON(data string) (*mcp.Tool, error) { + var tool mcp.Tool + err := json.Unmarshal([]byte(data), &tool) + if err != nil { + return nil, err + } + return &tool, nil +} + +// BackendToolWithMetadata represents a backend tool with similarity score. +type BackendToolWithMetadata struct { + BackendTool + Similarity float32 `json:"similarity"` // Cosine similarity from chromem-go (0-1, higher is better) +} + +// RegistryToolWithMetadata represents a registry tool with server information and similarity distance. +type RegistryToolWithMetadata struct { + ServerName string `json:"server_name"` + ServerDescription *string `json:"server_description,omitempty"` + Distance float64 `json:"distance"` // Cosine distance from query embedding + Tool RegistryTool `json:"tool"` +} + +// BackendWithRegistry represents a backend server with its resolved registry relationship. +type BackendWithRegistry struct { + Backend BackendServer `json:"backend"` + Registry *RegistryServer `json:"registry,omitempty"` // NULL if autonomous +} + +// EffectiveDescription returns the description (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveDescription() *string { + if b.Registry != nil { + return b.Registry.Description + } + return b.Backend.Description +} + +// EffectiveEmbedding returns the embedding (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveEmbedding() []float32 { + if b.Registry != nil { + return b.Registry.ServerEmbedding + } + return b.Backend.ServerEmbedding +} + +// ServerNameForTools returns the server name to use as context for tool embeddings. +func (b *BackendWithRegistry) ServerNameForTools() string { + if b.Registry != nil { + return b.Registry.Name + } + return b.Backend.Name +} + +// TokenMetrics represents token efficiency metrics for tool filtering. +type TokenMetrics struct { + BaselineTokens int `json:"baseline_tokens"` // Total tokens for all running server tools + ReturnedTokens int `json:"returned_tokens"` // Total tokens for returned/filtered tools + TokensSaved int `json:"tokens_saved"` // Number of tokens saved by filtering + SavingsPercentage float64 `json:"savings_percentage"` // Percentage of tokens saved (0-100) +} + +// Validate checks if the token metrics are consistent. +func (t *TokenMetrics) Validate() error { + if t.TokensSaved != t.BaselineTokens-t.ReturnedTokens { + return ErrInvalidTokenMetrics + } + + var expectedPct float64 + if t.BaselineTokens > 0 { + expectedPct = (float64(t.TokensSaved) / float64(t.BaselineTokens)) * 100 + // Allow small floating point differences (0.01%) + if expectedPct-t.SavingsPercentage > 0.01 || t.SavingsPercentage-expectedPct > 0.01 { + return ErrInvalidTokenMetrics + } + } else if t.SavingsPercentage != 0.0 { + return ErrInvalidTokenMetrics + } + + return nil +} diff --git a/pkg/optimizer/models/models_test.go b/pkg/optimizer/models/models_test.go new file mode 100644 index 0000000000..6fea81c927 --- /dev/null +++ b/pkg/optimizer/models/models_test.go @@ -0,0 +1,270 @@ +package models + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestRegistryServer_Validate(t *testing.T) { + t.Parallel() + url := "http://example.com/mcp" + pkg := "github.com/example/mcp-server" + + tests := []struct { + name string + server *RegistryServer + wantErr bool + }{ + { + name: "Remote server with URL is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + URL: &url, + }, + wantErr: false, + }, + { + name: "Container server with package is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + Package: &pkg, + }, + wantErr: false, + }, + { + name: "Remote server without URL is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + }, + wantErr: true, + }, + { + name: "Container server without package is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.server.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("RegistryServer.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestToolDetailsToJSON(t *testing.T) { + t.Parallel() + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + } + + json, err := ToolDetailsToJSON(tool) + if err != nil { + t.Fatalf("ToolDetailsToJSON() error = %v", err) + } + + if json == "" { + t.Error("ToolDetailsToJSON() returned empty string") + } + + // Try to parse it back + parsed, err := ToolDetailsFromJSON(json) + if err != nil { + t.Fatalf("ToolDetailsFromJSON() error = %v", err) + } + + if parsed.Name != tool.Name { + t.Errorf("Tool name mismatch: got %v, want %v", parsed.Name, tool.Name) + } + + if parsed.Description != tool.Description { + t.Errorf("Tool description mismatch: got %v, want %v", parsed.Description, tool.Description) + } +} + +func TestTokenMetrics_Validate(t *testing.T) { + t.Parallel() + tests := []struct { + name string + metrics *TokenMetrics + wantErr bool + }{ + { + name: "Valid metrics with savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 40.0, + }, + wantErr: false, + }, + { + name: "Valid metrics with no savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 1000, + TokensSaved: 0, + SavingsPercentage: 0.0, + }, + wantErr: false, + }, + { + name: "Invalid: tokens saved doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 500, // Should be 400 + SavingsPercentage: 40.0, + }, + wantErr: true, + }, + { + name: "Invalid: savings percentage doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 50.0, // Should be 40.0 + }, + wantErr: true, + }, + { + name: "Invalid: non-zero percentage with zero baseline", + metrics: &TokenMetrics{ + BaselineTokens: 0, + ReturnedTokens: 0, + TokensSaved: 0, + SavingsPercentage: 10.0, // Should be 0 + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.metrics.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("TokenMetrics.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBackendWithRegistry_EffectiveDescription(t *testing.T) { + t.Parallel() + registryDesc := "Registry description" + backendDesc := "Backend description" + + tests := []struct { + name string + w *BackendWithRegistry + want *string + }{ + { + name: "Uses registry description when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Description: ®istryDesc, + }, + }, + }, + want: ®istryDesc, + }, + { + name: "Uses backend description when no registry", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: nil, + }, + want: &backendDesc, + }, + { + name: "Returns nil when no description", + w: &BackendWithRegistry{ + Backend: BackendServer{}, + Registry: nil, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.w.EffectiveDescription() + if (got == nil) != (tt.want == nil) { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", got, tt.want) + } + if got != nil && tt.want != nil && *got != *tt.want { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", *got, *tt.want) + } + }) + } +} + +func TestBackendWithRegistry_ServerNameForTools(t *testing.T) { + t.Parallel() + tests := []struct { + name string + w *BackendWithRegistry + want string + }{ + { + name: "Uses registry name when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Name: "backend-name", + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Name: "registry-name", + }, + }, + }, + want: "registry-name", + }, + { + name: "Uses backend name when no registry", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Name: "backend-name", + }, + Registry: nil, + }, + want: "backend-name", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.w.ServerNameForTools(); got != tt.want { + t.Errorf("BackendWithRegistry.ServerNameForTools() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/optimizer/models/transport.go b/pkg/optimizer/models/transport.go new file mode 100644 index 0000000000..c8e5c0ce41 --- /dev/null +++ b/pkg/optimizer/models/transport.go @@ -0,0 +1,111 @@ +package models + +import ( + "database/sql/driver" + "fmt" +) + +// TransportType represents the transport protocol used by an MCP server. +// Maps 1:1 to ToolHive transport modes. +type TransportType string + +const ( + // TransportSSE represents Server-Sent Events transport + TransportSSE TransportType = "sse" + // TransportStreamable represents Streamable HTTP transport + TransportStreamable TransportType = "streamable-http" +) + +// Valid returns true if the transport type is valid +func (t TransportType) Valid() bool { + switch t { + case TransportSSE, TransportStreamable: + return true + default: + return false + } +} + +// String returns the string representation +func (t TransportType) String() string { + return string(t) +} + +// Value implements the driver.Valuer interface for database storage +func (t TransportType) Value() (driver.Value, error) { + if !t.Valid() { + return nil, fmt.Errorf("invalid transport type: %s", t) + } + return string(t), nil +} + +// Scan implements the sql.Scanner interface for database retrieval +func (t *TransportType) Scan(value interface{}) error { + if value == nil { + return fmt.Errorf("transport type cannot be nil") + } + + str, ok := value.(string) + if !ok { + return fmt.Errorf("transport type must be a string, got %T", value) + } + + *t = TransportType(str) + if !t.Valid() { + return fmt.Errorf("invalid transport type from database: %s", str) + } + + return nil +} + +// MCPStatus represents the status of an MCP server backend. +type MCPStatus string + +const ( + // StatusRunning indicates the backend is running + StatusRunning MCPStatus = "running" + // StatusStopped indicates the backend is stopped + StatusStopped MCPStatus = "stopped" +) + +// Valid returns true if the status is valid +func (s MCPStatus) Valid() bool { + switch s { + case StatusRunning, StatusStopped: + return true + default: + return false + } +} + +// String returns the string representation +func (s MCPStatus) String() string { + return string(s) +} + +// Value implements the driver.Valuer interface for database storage +func (s MCPStatus) Value() (driver.Value, error) { + if !s.Valid() { + return nil, fmt.Errorf("invalid MCP status: %s", s) + } + return string(s), nil +} + +// Scan implements the sql.Scanner interface for database retrieval +func (s *MCPStatus) Scan(value interface{}) error { + if value == nil { + return fmt.Errorf("MCP status cannot be nil") + } + + str, ok := value.(string) + if !ok { + return fmt.Errorf("MCP status must be a string, got %T", value) + } + + *s = MCPStatus(str) + if !s.Valid() { + return fmt.Errorf("invalid MCP status from database: %s", str) + } + + return nil +} diff --git a/pkg/optimizer/models/transport_test.go b/pkg/optimizer/models/transport_test.go new file mode 100644 index 0000000000..a70b1032f9 --- /dev/null +++ b/pkg/optimizer/models/transport_test.go @@ -0,0 +1,273 @@ +package models + +import ( + "testing" +) + +func TestTransportType_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + want bool + }{ + { + name: "SSE transport is valid", + transport: TransportSSE, + want: true, + }, + { + name: "Streamable transport is valid", + transport: TransportStreamable, + want: true, + }, + { + name: "Invalid transport is not valid", + transport: TransportType("invalid"), + want: false, + }, + { + name: "Empty transport is not valid", + transport: TransportType(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.transport.Valid(); got != tt.want { + t.Errorf("TransportType.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTransportType_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + wantValue string + wantErr bool + }{ + { + name: "SSE transport value", + transport: TransportSSE, + wantValue: "sse", + wantErr: false, + }, + { + name: "Streamable transport value", + transport: TransportStreamable, + wantValue: "streamable-http", + wantErr: false, + }, + { + name: "Invalid transport returns error", + transport: TransportType("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := tt.transport.Value() + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("TransportType.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestTransportType_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want TransportType + wantErr bool + }{ + { + name: "Scan SSE transport", + value: "sse", + want: TransportSSE, + wantErr: false, + }, + { + name: "Scan streamable transport", + value: "streamable-http", + want: TransportStreamable, + wantErr: false, + }, + { + name: "Scan invalid transport returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var transport TransportType + err := transport.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && transport != tt.want { + t.Errorf("TransportType.Scan() = %v, want %v", transport, tt.want) + } + }) + } +} + +func TestMCPStatus_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + want bool + }{ + { + name: "Running status is valid", + status: StatusRunning, + want: true, + }, + { + name: "Stopped status is valid", + status: StatusStopped, + want: true, + }, + { + name: "Invalid status is not valid", + status: MCPStatus("invalid"), + want: false, + }, + { + name: "Empty status is not valid", + status: MCPStatus(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.status.Valid(); got != tt.want { + t.Errorf("MCPStatus.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMCPStatus_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + wantValue string + wantErr bool + }{ + { + name: "Running status value", + status: StatusRunning, + wantValue: "running", + wantErr: false, + }, + { + name: "Stopped status value", + status: StatusStopped, + wantValue: "stopped", + wantErr: false, + }, + { + name: "Invalid status returns error", + status: MCPStatus("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := tt.status.Value() + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("MCPStatus.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestMCPStatus_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want MCPStatus + wantErr bool + }{ + { + name: "Scan running status", + value: "running", + want: StatusRunning, + wantErr: false, + }, + { + name: "Scan stopped status", + value: "stopped", + want: StatusStopped, + wantErr: false, + }, + { + name: "Scan invalid status returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var status MCPStatus + err := status.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && status != tt.want { + t.Errorf("MCPStatus.Scan() = %v, want %v", status, tt.want) + } + }) + } +} diff --git a/pkg/optimizer/tokens/counter.go b/pkg/optimizer/tokens/counter.go new file mode 100644 index 0000000000..d6c922ce7c --- /dev/null +++ b/pkg/optimizer/tokens/counter.go @@ -0,0 +1,65 @@ +// Package tokens provides token counting utilities for LLM cost estimation. +// It estimates token counts for MCP tools and their metadata. +package tokens + +import ( + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Counter counts tokens for LLM consumption +// This provides estimates of token usage for tools +type Counter struct { + // Simple heuristic: ~4 characters per token for English text + charsPerToken float64 +} + +// NewCounter creates a new token counter +func NewCounter() *Counter { + return &Counter{ + charsPerToken: 4.0, // GPT-style tokenization approximation + } +} + +// CountToolTokens estimates the number of tokens for a tool +func (c *Counter) CountToolTokens(tool mcp.Tool) int { + // Convert tool to JSON representation (as it would be sent to LLM) + toolJSON, err := json.Marshal(tool) + if err != nil { + // Fallback to simple estimation + return c.estimateFromTool(tool) + } + + // Estimate tokens from JSON length + return int(float64(len(toolJSON)) / c.charsPerToken) +} + +// estimateFromTool provides a fallback estimation from tool fields +func (c *Counter) estimateFromTool(tool mcp.Tool) int { + totalChars := len(tool.Name) + + if tool.Description != "" { + totalChars += len(tool.Description) + } + + // Estimate input schema size + schemaJSON, _ := json.Marshal(tool.InputSchema) + totalChars += len(schemaJSON) + + return int(float64(totalChars) / c.charsPerToken) +} + +// CountToolsTokens calculates total tokens for multiple tools +func (c *Counter) CountToolsTokens(tools []mcp.Tool) int { + total := 0 + for _, tool := range tools { + total += c.CountToolTokens(tool) + } + return total +} + +// EstimateText estimates tokens for arbitrary text +func (c *Counter) EstimateText(text string) int { + return int(float64(len(text)) / c.charsPerToken) +} diff --git a/pkg/optimizer/tokens/counter_test.go b/pkg/optimizer/tokens/counter_test.go new file mode 100644 index 0000000000..617ddd91ba --- /dev/null +++ b/pkg/optimizer/tokens/counter_test.go @@ -0,0 +1,143 @@ +package tokens + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestCountToolTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool for counting tokens", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count, got %d", tokens) + } + + // Rough estimate: tool should have at least a few tokens + if tokens < 5 { + t.Errorf("Expected at least 5 tokens for a tool with name and description, got %d", tokens) + } +} + +func TestCountToolTokens_MinimalTool(t *testing.T) { + t.Parallel() + counter := NewCounter() + + // Minimal tool with just a name + tool := mcp.Tool{ + Name: "minimal", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number even for minimal tool + if tokens <= 0 { + t.Errorf("Expected positive token count for minimal tool, got %d", tokens) + } +} + +func TestCountToolTokens_NoDescription(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + } + + tokens := counter.CountToolTokens(tool) + + // Should still return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count for tool without description, got %d", tokens) + } +} + +func TestCountToolsTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "First tool", + }, + { + Name: "tool2", + Description: "Second tool with longer description", + }, + } + + totalTokens := counter.CountToolsTokens(tools) + + // Should be greater than individual tools + tokens1 := counter.CountToolTokens(tools[0]) + tokens2 := counter.CountToolTokens(tools[1]) + + expectedTotal := tokens1 + tokens2 + if totalTokens != expectedTotal { + t.Errorf("Expected total tokens %d, got %d", expectedTotal, totalTokens) + } +} + +func TestCountToolsTokens_EmptyList(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tokens := counter.CountToolsTokens([]mcp.Tool{}) + + // Should return 0 for empty list + if tokens != 0 { + t.Errorf("Expected 0 tokens for empty list, got %d", tokens) + } +} + +func TestEstimateText(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tests := []struct { + name string + text string + want int + }{ + { + name: "Empty text", + text: "", + want: 0, + }, + { + name: "Short text", + text: "Hello", + want: 1, // 5 chars / 4 chars per token ≈ 1 + }, + { + name: "Medium text", + text: "This is a test message", + want: 5, // 22 chars / 4 chars per token ≈ 5 + }, + { + name: "Long text", + text: "This is a much longer test message that should have more tokens because it contains significantly more characters", + want: 28, // 112 chars / 4 chars per token = 28 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := counter.EstimateText(tt.text) + if got != tt.want { + t.Errorf("EstimateText() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index a6b45d8daa..d1564e3c12 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -125,6 +125,13 @@ type Config struct { // See audit.Config for available configuration options. // +optional Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` + + // Optimizer configures the MCP optimizer for context optimization on large toolsets. + // When enabled, vMCP exposes optim.find_tool and optim.call_tool operations to clients + // instead of all backend tools directly. This reduces token usage by allowing + // LLMs to discover relevant tools on demand rather than receiving all tool definitions. + // +optional + Optimizer *OptimizerConfig `json:"optimizer,omitempty" yaml:"optimizer,omitempty"` } // IncomingAuthConfig configures client authentication to the virtual MCP server. @@ -634,6 +641,80 @@ type OutputProperty struct { Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"` } +// OptimizerConfig configures the MCP optimizer for semantic tool discovery. +// The optimizer reduces token usage by allowing LLMs to discover relevant tools +// on demand rather than receiving all tool definitions upfront. +// +kubebuilder:object:generate=true +// +gendoc +type OptimizerConfig struct { + // Enabled determines whether the optimizer is active. + // When true, vMCP exposes optim.find_tool and optim.call_tool instead of all backend tools. + // +optional + Enabled bool `json:"enabled" yaml:"enabled"` + + // EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + // - "ollama": Uses local Ollama HTTP API for embeddings + // - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + // - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + // +kubebuilder:validation:Enum=ollama;openai-compatible;placeholder + // +optional + EmbeddingBackend string `json:"embeddingBackend,omitempty" yaml:"embeddingBackend,omitempty"` + + // EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "http://localhost:11434" + // - vLLM: "http://vllm-service:8000/v1" + // - OpenAI: "https://api.openai.com/v1" + // +optional + EmbeddingURL string `json:"embeddingURL,omitempty" yaml:"embeddingURL,omitempty"` + + // EmbeddingModel is the model name to use for embeddings. + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "nomic-embed-text", "all-minilm" + // - vLLM: "BAAI/bge-small-en-v1.5" + // - OpenAI: "text-embedding-3-small" + // +optional + EmbeddingModel string `json:"embeddingModel,omitempty" yaml:"embeddingModel,omitempty"` + + // EmbeddingDimension is the dimension of the embedding vectors. + // Common values: + // - 384: all-MiniLM-L6-v2, nomic-embed-text + // - 768: BAAI/bge-small-en-v1.5 + // - 1536: OpenAI text-embedding-3-small + // +kubebuilder:validation:Minimum=1 + // +optional + EmbeddingDimension int `json:"embeddingDimension,omitempty" yaml:"embeddingDimension,omitempty"` + + // PersistPath is the optional filesystem path for persisting the chromem-go database. + // If empty, the database will be in-memory only (ephemeral). + // When set, tool metadata and embeddings are persisted to disk for faster restarts. + // +optional + PersistPath string `json:"persistPath,omitempty" yaml:"persistPath,omitempty"` + + // FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // Hybrid search (semantic + BM25) is always enabled. + // +optional + FTSDBPath string `json:"ftsDBPath,omitempty" yaml:"ftsDBPath,omitempty"` + + // HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + // Value range: 0.0 (all BM25) to 1.0 (all semantic). + // Default: 0.7 (70% semantic, 30% BM25) + // Only used when FTSDBPath is set. + // +optional + // +kubebuilder:validation:Minimum=0.0 + // +kubebuilder:validation:Maximum=1.0 + HybridSearchRatio *float64 `json:"hybridSearchRatio,omitempty" yaml:"hybridSearchRatio,omitempty"` + + // EmbeddingService is the name of a Kubernetes Service that provides embeddings (K8s only). + // This is an alternative to EmbeddingURL for in-cluster deployments. + // When set, vMCP will resolve the service DNS name for the embedding API. + // +optional + EmbeddingService string `json:"embeddingService,omitempty" yaml:"embeddingService,omitempty"` +} + // Validator validates configuration. type Validator interface { // Validate checks if the configuration is valid. diff --git a/pkg/vmcp/config/zz_generated.deepcopy.go b/pkg/vmcp/config/zz_generated.deepcopy.go index 97b75415dd..6550ecddc7 100644 --- a/pkg/vmcp/config/zz_generated.deepcopy.go +++ b/pkg/vmcp/config/zz_generated.deepcopy.go @@ -187,6 +187,11 @@ func (in *Config) DeepCopyInto(out *Config) { *out = new(audit.Config) (*in).DeepCopyInto(*out) } + if in.Optimizer != nil { + in, out := &in.Optimizer, &out.Optimizer + *out = new(OptimizerConfig) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Config. @@ -324,6 +329,21 @@ func (in *OperationalConfig) DeepCopy() *OperationalConfig { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *OptimizerConfig) DeepCopyInto(out *OptimizerConfig) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new OptimizerConfig. +func (in *OptimizerConfig) DeepCopy() *OptimizerConfig { + if in == nil { + return nil + } + out := new(OptimizerConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *OutgoingAuthConfig) DeepCopyInto(out *OutgoingAuthConfig) { *out = *in diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go new file mode 100644 index 0000000000..4a24d95576 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer.go @@ -0,0 +1,364 @@ +// Package optimizer provides vMCP integration for semantic tool discovery. +// +// This package implements the RFC-0022 optimizer integration, exposing: +// - optim.find_tool: Semantic/keyword-based tool discovery +// - optim.call_tool: Dynamic tool invocation across backends +// +// Architecture: +// - Embeddings are generated during session initialization (OnRegisterSession hook) +// - Tools are exposed as standard MCP tools callable via tools/call +// - Integrates with vMCP's two-boundary authentication model +// - Uses existing router for backend tool invocation +package optimizer + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/db" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/optimizer/ingestion" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" +) + +// Config holds optimizer configuration for vMCP integration. +type Config struct { + // Enabled controls whether optimizer tools are available + Enabled bool + + // PersistPath is the optional path for chromem-go database persistence (empty = in-memory) + PersistPath string + + // FTSDBPath is the path to SQLite FTS5 database for BM25 search + // (empty = auto-default: ":memory:" or "{PersistPath}/fts.db") + FTSDBPath string + + // HybridSearchRatio controls semantic vs BM25 mix (0.0-1.0, default: 0.7) + HybridSearchRatio float64 + + // EmbeddingConfig configures the embedding backend (vLLM, Ollama, placeholder) + EmbeddingConfig *embeddings.Config +} + +// OptimizerIntegration manages optimizer functionality within vMCP. +// +//nolint:revive // Name is intentional for clarity in external packages +type OptimizerIntegration struct { + config *Config + ingestionService *ingestion.Service + mcpServer *server.MCPServer // For registering tools + backendClient vmcp.BackendClient // For querying backends at startup +} + +// NewIntegration creates a new optimizer integration. +func NewIntegration( + _ context.Context, + cfg *Config, + mcpServer *server.MCPServer, + backendClient vmcp.BackendClient, +) (*OptimizerIntegration, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil // Optimizer disabled + } + + // Initialize ingestion service with embedding backend + ingestionCfg := &ingestion.Config{ + DBConfig: &db.Config{ + PersistPath: cfg.PersistPath, + FTSDBPath: cfg.FTSDBPath, + }, + EmbeddingConfig: cfg.EmbeddingConfig, + } + + svc, err := ingestion.NewService(ingestionCfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize optimizer service: %w", err) + } + + return &OptimizerIntegration{ + config: cfg, + ingestionService: svc, + mcpServer: mcpServer, + backendClient: backendClient, + }, nil +} + +// OnRegisterSession is called during session initialization to generate embeddings +// and register optimizer tools. +// +// This hook: +// 1. Extracts backend tools from discovered capabilities +// 2. Generates embeddings for all tools (parallel per-backend) +// 3. Registers optim.find_tool and optim.call_tool as session tools +func (o *OptimizerIntegration) OnRegisterSession( + ctx context.Context, + session server.ClientSession, + capabilities *aggregator.AggregatedCapabilities, +) error { + if o == nil { + return nil // Optimizer not enabled + } + + sessionID := session.SessionID() + logger.Infow("Generating embeddings for session", "session_id", sessionID) + + // Group tools by backend for parallel processing + type backendTools struct { + backendID string + backendName string + backendURL string + transport string + tools []mcp.Tool + } + + backendMap := make(map[string]*backendTools) + + // Extract tools from routing table + if capabilities.RoutingTable != nil { + for toolName, target := range capabilities.RoutingTable.Tools { + // Find the tool definition from capabilities.Tools + var toolDef mcp.Tool + found := false + for i := range capabilities.Tools { + if capabilities.Tools[i].Name == toolName { + // Convert vmcp.Tool to mcp.Tool + // Note: vmcp.Tool.InputSchema is map[string]any, mcp.Tool.InputSchema is ToolInputSchema struct + // For ingestion, we just need the tool name and description + toolDef = mcp.Tool{ + Name: capabilities.Tools[i].Name, + Description: capabilities.Tools[i].Description, + // InputSchema will be empty - we only need name/description for embedding generation + } + found = true + break + } + } + if !found { + logger.Warnw("Tool in routing table but not in capabilities", + "tool_name", toolName, + "backend_id", target.WorkloadID) + continue + } + + // Group by backend + if _, exists := backendMap[target.WorkloadID]; !exists { + backendMap[target.WorkloadID] = &backendTools{ + backendID: target.WorkloadID, + backendName: target.WorkloadName, + backendURL: target.BaseURL, + transport: target.TransportType, + tools: []mcp.Tool{}, + } + } + backendMap[target.WorkloadID].tools = append(backendMap[target.WorkloadID].tools, toolDef) + } + } + + // Ingest each backend's tools (in parallel - TODO: add goroutines) + for _, bt := range backendMap { + logger.Debugw("Ingesting backend for session", + "session_id", sessionID, + "backend_id", bt.backendID, + "backend_name", bt.backendName, + "tool_count", len(bt.tools)) + + // Ingest server with simplified metadata + // Note: URL and transport are not stored - vMCP manages backend lifecycle + err := o.ingestionService.IngestServer( + ctx, + bt.backendID, + bt.backendName, + nil, // description + bt.tools, + ) + if err != nil { + logger.Errorw("Failed to ingest backend", + "session_id", sessionID, + "backend_id", bt.backendID, + "error", err) + // Continue with other backends + } + } + + logger.Infow("Embeddings generated for session", + "session_id", sessionID, + "backend_count", len(backendMap)) + + return nil +} + +// RegisterTools adds optimizer tools to the session. +// This should be called after OnRegisterSession completes. +func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.ClientSession) error { + if o == nil { + return nil // Optimizer not enabled + } + + sessionID := session.SessionID() + + // Define optimizer tools with handlers + optimizerTools := []server.ServerTool{ + { + Tool: mcp.Tool{ + Name: "optim.find_tool", + Description: "Semantic search across all backend tools using natural language description and optional keywords", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + }, + }, + Handler: o.createFindToolHandler(), + }, + { + Tool: mcp.Tool{ + Name: "optim.call_tool", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + }, + }, + Handler: o.createCallToolHandler(), + }, + } + + // Add tools to session + if err := o.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return fmt.Errorf("failed to add optimizer tools to session: %w", err) + } + + logger.Debugw("Optimizer tools registered", "session_id", sessionID) + return nil +} + +// createFindToolHandler creates the handler for optim.find_tool +func (*OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // TODO: Implement semantic search + // 1. Extract tool_description and tool_keywords from request.Params.Arguments + // 2. Call optimizer search service (hybrid semantic + BM25) + // 3. Return ranked list of tools with scores and token metrics + + logger.Debugw("optim.find_tool called", "request", request) + + return mcp.NewToolResultError("optim.find_tool not yet implemented"), nil + } +} + +// createCallToolHandler creates the handler for optim.call_tool +func (*OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // TODO: Implement dynamic tool invocation + // 1. Extract backend_id, tool_name, parameters from request.Params.Arguments + // 2. Validate backend and tool exist + // 3. Route to backend via existing router + // 4. Return result + + logger.Debugw("optim.call_tool called", "request", request) + + return mcp.NewToolResultError("optim.call_tool not yet implemented"), nil + } +} + +// IngestInitialBackends ingests all discovered backends and their tools at startup. +// This should be called after backends are discovered during server initialization. +func (o *OptimizerIntegration) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error { + if o == nil || o.ingestionService == nil { + return nil // Optimizer disabled + } + + logger.Infof("Ingesting %d discovered backends into optimizer", len(backends)) + + for _, backend := range backends { + // Convert Backend to BackendTarget for client API + target := vmcp.BackendToTarget(&backend) + if target == nil { + logger.Warnf("Failed to convert backend %s to target", backend.Name) + continue + } + + // Query backend capabilities to get its tools + capabilities, err := o.backendClient.ListCapabilities(ctx, target) + if err != nil { + logger.Warnf("Failed to query capabilities for backend %s: %v", backend.Name, err) + continue // Skip this backend but continue with others + } + + // Extract tools from capabilities + // Note: For ingestion, we only need name and description (for generating embeddings) + // InputSchema is not used by the ingestion service + var tools []mcp.Tool + for _, tool := range capabilities.Tools { + tools = append(tools, mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + // InputSchema not needed for embedding generation + }) + } + + // Get description from metadata (may be empty) + var description *string + if backend.Metadata != nil { + if desc := backend.Metadata["description"]; desc != "" { + description = &desc + } + } + + // Ingest this backend's tools + if err := o.ingestionService.IngestServer( + ctx, + backend.ID, + backend.Name, + description, + tools, + ); err != nil { + logger.Warnf("Failed to ingest backend %s: %v", backend.Name, err) + continue // Log but don't fail startup + } + } + + logger.Info("Initial backend ingestion completed") + return nil +} + +// Close cleans up optimizer resources. +func (o *OptimizerIntegration) Close() error { + if o == nil || o.ingestionService == nil { + return nil + } + return o.ingestionService.Close() +} diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go new file mode 100644 index 0000000000..82a51a925a --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -0,0 +1,167 @@ +package optimizer + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" +) + +// mockBackendClient implements vmcp.BackendClient for integration testing +type mockIntegrationBackendClient struct { + backends map[string]*vmcp.CapabilityList +} + +func newMockIntegrationBackendClient() *mockIntegrationBackendClient { + return &mockIntegrationBackendClient{ + backends: make(map[string]*vmcp.CapabilityList), + } +} + +func (m *mockIntegrationBackendClient) addBackend(backendID string, caps *vmcp.CapabilityList) { + m.backends[backendID] = caps +} + +func (m *mockIntegrationBackendClient) ListCapabilities(_ context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if caps, exists := m.backends[target.WorkloadID]; exists { + return caps, nil + } + return &vmcp.CapabilityList{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + return nil, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// mockIntegrationSession implements server.ClientSession for testing +type mockIntegrationSession struct { + sessionID string +} + +func (m *mockIntegrationSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestOptimizerIntegration_WithVMCP tests the complete integration with vMCP +func TestOptimizerIntegration_WithVMCP(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + }, + }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + } + + // Create optimizer integration + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Simulate session registration + session := &mockIntegrationSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + BackendID: "github", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "create_issue": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Note: We don't test RegisterTools here because it requires the session + // to be properly registered with the MCP server, which is beyond the scope + // of this integration test. The RegisterTools method is tested separately + // in unit tests where we can properly mock the MCP server behavior. +} diff --git a/pkg/vmcp/optimizer/optimizer_unit_test.go b/pkg/vmcp/optimizer/optimizer_unit_test.go new file mode 100644 index 0000000000..794069b851 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_unit_test.go @@ -0,0 +1,260 @@ +package optimizer + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" +) + +// mockBackendClient implements vmcp.BackendClient for testing +type mockBackendClient struct { + capabilities *vmcp.CapabilityList + err error +} + +func (m *mockBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if m.err != nil { + return nil, m.err + } + return m.capabilities, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + return nil, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// mockSession implements server.ClientSession for testing +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestNewIntegration_Disabled tests that nil is returned when optimizer is disabled +func TestNewIntegration_Disabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Test with nil config + integration, err := NewIntegration(ctx, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when config is nil") + + // Test with disabled config + config := &Config{Enabled: false} + integration, err = NewIntegration(ctx, config, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when optimizer is disabled") +} + +// TestNewIntegration_Enabled tests successful creation +func TestNewIntegration_Enabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + } + + integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + require.NoError(t, err) + require.NotNil(t, integration) + defer func() { _ = integration.Close() }() +} + +// TestOnRegisterSession tests session registration +func TestOnRegisterSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + } + + integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "test_tool", + Description: "A test tool", + BackendID: "backend-1", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-1", + WorkloadName: "Test Backend", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestOnRegisterSession_NilIntegration tests nil integration handling +func TestOnRegisterSession_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + err := integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestRegisterTools tests tool registration behavior +func TestRegisterTools(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + } + + integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + // RegisterTools will fail with "session not found" because the mock session + // is not actually registered with the MCP server. This is expected behavior. + // We're just testing that the method executes without panicking. + _ = integration.RegisterTools(ctx, session) +} + +// TestRegisterTools_NilIntegration tests nil integration handling +func TestRegisterTools_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + + err := integration.RegisterTools(ctx, session) + assert.NoError(t, err) +} + +// TestClose tests cleanup +func TestClose(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "placeholder", + Dimension: 384, + }, + } + + integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + require.NoError(t, err) + + err = integration.Close() + assert.NoError(t, err) + + // Multiple closes should be safe + err = integration.Close() + assert.NoError(t, err) +} + +// TestClose_NilIntegration tests nil integration close +func TestClose_NilIntegration(t *testing.T) { + t.Parallel() + + var integration *OptimizerIntegration = nil + err := integration.Close() + assert.NoError(t, err) +} diff --git a/pkg/vmcp/router/default_router.go b/pkg/vmcp/router/default_router.go index c38b2145bf..7e32731ed9 100644 --- a/pkg/vmcp/router/default_router.go +++ b/pkg/vmcp/router/default_router.go @@ -3,6 +3,7 @@ package router import ( "context" "fmt" + "strings" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -75,7 +76,21 @@ func routeCapability( // RouteTool resolves a tool name to its backend target. // With lazy discovery, this method gets capabilities from the request context // instead of using a cached routing table. +// +// Special handling for optimizer tools: +// - Tools with "optim." prefix (optim.find_tool, optim.call_tool) are handled by vMCP itself +// - These tools are registered during session initialization and don't route to backends +// - The SDK handles these tools directly via registered handlers func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error) { + // Optimizer tools (optim.*) are handled by vMCP itself, not routed to backends. + // The SDK will invoke the registered handler directly. + // We return ErrToolNotFound here so the handler factory doesn't try to create + // a backend routing handler for these tools. + if strings.HasPrefix(toolName, "optim.") { + logger.Debugf("Optimizer tool %s is handled by vMCP, not routed to backend", toolName) + return nil, fmt.Errorf("%w: optimizer tool %s is handled by vMCP", ErrToolNotFound, toolName) + } + return routeCapability( ctx, toolName, diff --git a/pkg/vmcp/server/mocks/mock_watcher.go b/pkg/vmcp/server/mocks/mock_watcher.go index 6bfdac7f0b..4044825b14 100644 --- a/pkg/vmcp/server/mocks/mock_watcher.go +++ b/pkg/vmcp/server/mocks/mock_watcher.go @@ -13,6 +13,9 @@ import ( context "context" reflect "reflect" + server "github.com/mark3labs/mcp-go/server" + vmcp "github.com/stacklok/toolhive/pkg/vmcp" + aggregator "github.com/stacklok/toolhive/pkg/vmcp/aggregator" gomock "go.uber.org/mock/gomock" ) @@ -53,3 +56,83 @@ func (mr *MockWatcherMockRecorder) WaitForCacheSync(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WaitForCacheSync", reflect.TypeOf((*MockWatcher)(nil).WaitForCacheSync), ctx) } + +// MockOptimizerIntegration is a mock of OptimizerIntegration interface. +type MockOptimizerIntegration struct { + ctrl *gomock.Controller + recorder *MockOptimizerIntegrationMockRecorder + isgomock struct{} +} + +// MockOptimizerIntegrationMockRecorder is the mock recorder for MockOptimizerIntegration. +type MockOptimizerIntegrationMockRecorder struct { + mock *MockOptimizerIntegration +} + +// NewMockOptimizerIntegration creates a new mock instance. +func NewMockOptimizerIntegration(ctrl *gomock.Controller) *MockOptimizerIntegration { + mock := &MockOptimizerIntegration{ctrl: ctrl} + mock.recorder = &MockOptimizerIntegrationMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOptimizerIntegration) EXPECT() *MockOptimizerIntegrationMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockOptimizerIntegration) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockOptimizerIntegrationMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockOptimizerIntegration)(nil).Close)) +} + +// IngestInitialBackends mocks base method. +func (m *MockOptimizerIntegration) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IngestInitialBackends", ctx, backends) + ret0, _ := ret[0].(error) + return ret0 +} + +// IngestInitialBackends indicates an expected call of IngestInitialBackends. +func (mr *MockOptimizerIntegrationMockRecorder) IngestInitialBackends(ctx, backends any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IngestInitialBackends", reflect.TypeOf((*MockOptimizerIntegration)(nil).IngestInitialBackends), ctx, backends) +} + +// OnRegisterSession mocks base method. +func (m *MockOptimizerIntegration) OnRegisterSession(ctx context.Context, session server.ClientSession, capabilities *aggregator.AggregatedCapabilities) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "OnRegisterSession", ctx, session, capabilities) + ret0, _ := ret[0].(error) + return ret0 +} + +// OnRegisterSession indicates an expected call of OnRegisterSession. +func (mr *MockOptimizerIntegrationMockRecorder) OnRegisterSession(ctx, session, capabilities any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRegisterSession", reflect.TypeOf((*MockOptimizerIntegration)(nil).OnRegisterSession), ctx, session, capabilities) +} + +// RegisterTools mocks base method. +func (m *MockOptimizerIntegration) RegisterTools(ctx context.Context, session server.ClientSession) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterTools", ctx, session) + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterTools indicates an expected call of RegisterTools. +func (mr *MockOptimizerIntegrationMockRecorder) RegisterTools(ctx, session any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterTools", reflect.TypeOf((*MockOptimizerIntegration)(nil).RegisterTools), ctx, session) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 98866d0568..e0fc6235e4 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -20,6 +20,7 @@ import ( "github.com/stacklok/toolhive/pkg/audit" "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" "github.com/stacklok/toolhive/pkg/recovery" "github.com/stacklok/toolhive/pkg/telemetry" transportsession "github.com/stacklok/toolhive/pkg/transport/session" @@ -28,6 +29,7 @@ import ( "github.com/stacklok/toolhive/pkg/vmcp/composer" "github.com/stacklok/toolhive/pkg/vmcp/discovery" "github.com/stacklok/toolhive/pkg/vmcp/health" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" "github.com/stacklok/toolhive/pkg/vmcp/router" "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" @@ -119,6 +121,38 @@ type Config struct { // Only set when running in K8s with outgoingAuth.source: discovered. // Used for /readyz endpoint to gate readiness on cache sync. Watcher Watcher + + // OptimizerConfig is the optional optimizer configuration. + // If nil or Enabled=false, optimizer tools (optim.find_tool, optim.call_tool) are not available. + OptimizerConfig *OptimizerConfig +} + +// OptimizerConfig holds optimizer-specific configuration for vMCP integration. +type OptimizerConfig struct { + // Enabled controls whether optimizer tools are available + Enabled bool + + // PersistPath is the optional path for chromem-go database persistence (empty = in-memory) + PersistPath string + + // FTSDBPath is the path to SQLite FTS5 database for BM25 search + // (empty = auto-default: ":memory:" or "{PersistPath}/fts.db") + FTSDBPath string + + // HybridSearchRatio controls semantic vs BM25 mix (0.0-1.0, default: 0.7) + HybridSearchRatio float64 + + // EmbeddingBackend specifies the embedding provider (vllm, ollama, placeholder) + EmbeddingBackend string + + // EmbeddingURL is the URL for the embedding service (vLLM or Ollama) + EmbeddingURL string + + // EmbeddingModel is the model name for embeddings + EmbeddingModel string + + // EmbeddingDimension is the embedding vector dimension + EmbeddingDimension int } // Server is the Virtual MCP Server that aggregates multiple backends. @@ -187,6 +221,26 @@ type Server struct { // Lock for writes (initialization, disabling on start failure). healthMonitor *health.Monitor healthMonitorMu sync.RWMutex + + // optimizerIntegration provides semantic tool discovery via optim.find_tool and optim.call_tool. + // Nil if optimizer is disabled. + optimizerIntegration OptimizerIntegration +} + +// OptimizerIntegration is the interface for optimizer functionality in vMCP. +// This is defined as an interface to avoid circular dependencies and allow testing. +type OptimizerIntegration interface { + // IngestInitialBackends ingests all discovered backends at startup + IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error + + // OnRegisterSession generates embeddings for session tools + OnRegisterSession(ctx context.Context, session server.ClientSession, capabilities *aggregator.AggregatedCapabilities) error + + // RegisterTools adds optim.find_tool and optim.call_tool to the session + RegisterTools(ctx context.Context, session server.ClientSession) error + + // Close cleans up optimizer resources + Close() error } // New creates a new Virtual MCP Server instance. @@ -329,21 +383,61 @@ func New( logger.Info("Health monitoring disabled") } + // Initialize optimizer integration if enabled + var optimizerInteg OptimizerIntegration + if cfg.OptimizerConfig != nil && cfg.OptimizerConfig.Enabled { + logger.Infow("Initializing optimizer integration (chromem-go)", + "persist_path", cfg.OptimizerConfig.PersistPath, + "embedding_backend", cfg.OptimizerConfig.EmbeddingBackend) + + // Convert server config to optimizer config + hybridRatio := 0.7 // Default + if cfg.OptimizerConfig.HybridSearchRatio != 0 { + hybridRatio = cfg.OptimizerConfig.HybridSearchRatio + } + optimizerCfg := &optimizer.Config{ + Enabled: cfg.OptimizerConfig.Enabled, + PersistPath: cfg.OptimizerConfig.PersistPath, + FTSDBPath: cfg.OptimizerConfig.FTSDBPath, + HybridSearchRatio: hybridRatio, + EmbeddingConfig: &embeddings.Config{ + BackendType: cfg.OptimizerConfig.EmbeddingBackend, + BaseURL: cfg.OptimizerConfig.EmbeddingURL, + Model: cfg.OptimizerConfig.EmbeddingModel, + Dimension: cfg.OptimizerConfig.EmbeddingDimension, + }, + } + + optimizerInteg, err = optimizer.NewIntegration(ctx, optimizerCfg, mcpServer, backendClient) + if err != nil { + return nil, fmt.Errorf("failed to initialize optimizer: %w", err) + } + logger.Info("Optimizer integration initialized successfully") + + // Ingest discovered backends at startup (populate optimizer database) + initialBackends := backendRegistry.List(ctx) + if err := optimizerInteg.IngestInitialBackends(ctx, initialBackends); err != nil { + logger.Warnf("Failed to ingest initial backends: %v", err) + // Don't fail server startup - optimizer can still work with incremental ingestion + } + } + // Create Server instance srv := &Server{ - config: cfg, - mcpServer: mcpServer, - router: rt, - backendClient: backendClient, - handlerFactory: handlerFactory, - discoveryMgr: discoveryMgr, - backendRegistry: backendRegistry, - sessionManager: sessionManager, - capabilityAdapter: capabilityAdapter, - workflowDefs: workflowDefs, - workflowExecutors: workflowExecutors, - ready: make(chan struct{}), - healthMonitor: healthMon, + config: cfg, + mcpServer: mcpServer, + router: rt, + backendClient: backendClient, + handlerFactory: handlerFactory, + discoveryMgr: discoveryMgr, + backendRegistry: backendRegistry, + sessionManager: sessionManager, + capabilityAdapter: capabilityAdapter, + workflowDefs: workflowDefs, + workflowExecutors: workflowExecutors, + ready: make(chan struct{}), + healthMonitor: healthMon, + optimizerIntegration: optimizerInteg, } // Register OnRegisterSession hook to inject capabilities after SDK registers session. @@ -430,6 +524,30 @@ func New( "session_id", sessionID, "tool_count", len(caps.Tools), "resource_count", len(caps.Resources)) + + // Generate embeddings and register optimizer tools if enabled + if srv.optimizerIntegration != nil { + logger.Debugw("Generating embeddings for optimizer", "session_id", sessionID) + + // Generate embeddings for all tools in this session + if err := srv.optimizerIntegration.OnRegisterSession(ctx, session, caps); err != nil { + logger.Errorw("failed to generate embeddings for optimizer", + "error", err, + "session_id", sessionID) + // Don't fail session initialization - continue without optimizer + } else { + // Register optimizer tools (optim.find_tool, optim.call_tool) + if err := srv.optimizerIntegration.RegisterTools(ctx, session); err != nil { + logger.Errorw("failed to register optimizer tools", + "error", err, + "session_id", sessionID) + // Don't fail session initialization - continue without optimizer tools + } else { + logger.Infow("optimizer tools registered", + "session_id", sessionID) + } + } + } }) return srv, nil diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000000..09a382f6b0 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,96 @@ +# ToolHive Scripts + +Utility scripts for development, testing, and debugging. + +## Optimizer Database Inspection + +Tools to inspect the vMCP optimizer's hybrid database (chromem-go + SQLite FTS5). + +### SQLite FTS5 Database + +```bash +# Quick shell script wrapper +./scripts/inspect-optimizer-db.sh /tmp/vmcp-optimizer-fts.db + +# Or use sqlite3 directly +sqlite3 /tmp/vmcp-optimizer-fts.db "SELECT COUNT(*) FROM backend_tools_fts;" +``` + +### chromem-go Vector Database + +chromem-go stores data in binary `.gob` format. Use these Go scripts: + +#### Quick Summary +```bash +go run scripts/inspect-chromem-raw/inspect-chromem-raw.go /tmp/vmcp-optimizer-debug.db +``` +Shows collection sizes and first few documents from each collection. + +**Example output:** +``` +📁 Collection ID: 5ff43c0b + Documents: 4 + - Document ID: github + Content: github + Embedding: 384 dimensions + Type: backend_server +``` + +#### Detailed View +```bash +# View specific tool +go run scripts/view-chromem-tool/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents + +# View all documents +go run scripts/view-chromem-tool/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db + +# Search by name/content +go run scripts/view-chromem-tool/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db "search" +``` + +**Example output:** +``` +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Document ID: 4da1128d-7800-4d4a-a28e-9d1ad8fcb989 +Content: get_file_contents. Get the contents of a file... +Embedding Dimensions: 384 + +Metadata: + data: { + "id": "4da1128d-7800-4d4a-a28e-9d1ad8fcb989", + "mcpserver_id": "github", + "tool_name": "get_file_contents", + "description": "Get the contents of a file or directory...", + "token_count": 38, + ... + } + server_id: github + type: backend_tool + +Embedding (first 10): [0.000, 0.003, 0.001, 0.005, ...] +``` + +#### VSCode Integration + +For SQLite files, install the VSCode extension: +```bash +code --install-extension alexcvzz.vscode-sqlite +``` + +Then open any `.db` file in VSCode to browse tables visually. + +## Testing Scripts + +### Optimizer Tests +```bash +# Test with sqlite-vec extension +./scripts/test-optimizer-with-sqlite-vec.sh +``` + +## Contributing + +When adding new scripts: +1. Make shell scripts executable: `chmod +x scripts/your-script.sh` +2. Add error handling and usage instructions +3. Document the script in this README +4. Test on both macOS and Linux if possible diff --git a/scripts/inspect-chromem-raw/inspect-chromem-raw.go b/scripts/inspect-chromem-raw/inspect-chromem-raw.go new file mode 100644 index 0000000000..caef4d524f --- /dev/null +++ b/scripts/inspect-chromem-raw/inspect-chromem-raw.go @@ -0,0 +1,106 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "encoding/gob" + "fmt" + "os" + "path/filepath" +) + +// Minimal structures to decode chromem-go documents +type Document struct { + ID string + Metadata map[string]string + Embedding []float32 + Content string +} + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run inspect-chromem-raw.go ") + os.Exit(1) + } + + dbPath := os.Args[1] + fmt.Printf("📊 Raw inspection of chromem-go database: %s\n\n", dbPath) + + // Read all collection directories + entries, err := os.ReadDir(dbPath) + if err != nil { + fmt.Printf("Error reading directory: %v\n", err) + os.Exit(1) + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + collectionPath := filepath.Join(dbPath, entry.Name()) + fmt.Printf("📁 Collection ID: %s\n", entry.Name()) + + // Count gob files + gobFiles, err := filepath.Glob(filepath.Join(collectionPath, "*.gob")) + if err != nil { + fmt.Printf(" Error: %v\n", err) + continue + } + + fmt.Printf(" Documents: %d\n", len(gobFiles)) + + // Show first few documents + limit := 5 + if len(gobFiles) > limit { + fmt.Printf(" (showing first %d)\n", limit) + } + + for i, gobFile := range gobFiles { + if i >= limit { + break + } + + doc, err := decodeGobFile(gobFile) + if err != nil { + fmt.Printf(" - %s (error decoding: %v)\n", filepath.Base(gobFile), err) + continue + } + + fmt.Printf(" - Document ID: %s\n", doc.ID) + fmt.Printf(" Content: %s\n", truncate(doc.Content, 80)) + fmt.Printf(" Embedding: %d dimensions\n", len(doc.Embedding)) + if serverID, ok := doc.Metadata["server_id"]; ok { + fmt.Printf(" Server ID: %s\n", serverID) + } + if docType, ok := doc.Metadata["type"]; ok { + fmt.Printf(" Type: %s\n", docType) + } + } + fmt.Println() + } +} + +func decodeGobFile(path string) (*Document, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + dec := gob.NewDecoder(f) + var doc Document + if err := dec.Decode(&doc); err != nil { + return nil, err + } + + return &doc, nil +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/scripts/inspect-chromem/inspect-chromem.go b/scripts/inspect-chromem/inspect-chromem.go new file mode 100644 index 0000000000..672741b5ae --- /dev/null +++ b/scripts/inspect-chromem/inspect-chromem.go @@ -0,0 +1,123 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "fmt" + "os" + + "github.com/philippgille/chromem-go" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run inspect-chromem.go ") + fmt.Println("Example: go run inspect-chromem.go /tmp/vmcp-optimizer-debug.db") + os.Exit(1) + } + + dbPath := os.Args[1] + + // Open the chromem-go database + db, err := chromem.NewPersistentDB(dbPath, true) // true = read-only + if err != nil { + fmt.Printf("Error opening database: %v\n", err) + os.Exit(1) + } + + fmt.Printf("📊 Inspecting chromem-go database at: %s\n\n", dbPath) + + // List collections + fmt.Println("📁 Collections:") + fmt.Println(" - backend_servers") + fmt.Println(" - backend_tools") + fmt.Println() + + // Create a dummy embedding function (we're just inspecting, not querying) + dummyEmbedding := func(ctx context.Context, text string) ([]float32, error) { + return make([]float32, 384), nil // Placeholder + } + + // Inspect backend_servers collection + serversCol := db.GetCollection("backend_servers", dummyEmbedding) + if serversCol != nil { + count := serversCol.Count() + fmt.Printf("🖥️ Backend Servers Collection: %d documents\n", count) + + if count > 0 { + // Query all documents (using a generic query with high limit) + results, err := serversCol.Query(context.Background(), "", count, nil, nil) + if err == nil { + fmt.Println(" Servers:") + for _, doc := range results { + fmt.Printf(" - ID: %s\n", doc.ID) + fmt.Printf(" Content: %s\n", truncate(doc.Content, 80)) + if len(doc.Embedding) > 0 { + fmt.Printf(" Embedding: %d dimensions\n", len(doc.Embedding)) + } + fmt.Printf(" Metadata keys: %v\n", getKeys(doc.Metadata)) + } + } + } + } else { + fmt.Println("🖥️ Backend Servers Collection: not found") + } + fmt.Println() + + // Inspect backend_tools collection + toolsCol := db.GetCollection("backend_tools", dummyEmbedding) + if toolsCol != nil { + count := toolsCol.Count() + fmt.Printf("🔧 Backend Tools Collection: %d documents\n", count) + + if count > 0 && count < 20 { + // Only show details if there aren't too many + results, err := toolsCol.Query(context.Background(), "", count, nil, nil) + if err == nil { + fmt.Println(" Tools:") + for i, doc := range results { + if i >= 10 { + fmt.Printf(" ... and %d more tools\n", count-10) + break + } + fmt.Printf(" - ID: %s\n", doc.ID) + fmt.Printf(" Content: %s\n", truncate(doc.Content, 80)) + if len(doc.Embedding) > 0 { + fmt.Printf(" Embedding: %d dimensions\n", len(doc.Embedding)) + } + fmt.Printf(" Server ID: %s\n", doc.Metadata["server_id"]) + } + } + } else if count >= 20 { + fmt.Printf(" (too many to display, use query commands below)\n") + } + } else { + fmt.Println("🔧 Backend Tools Collection: not found") + } + fmt.Println() + + // Show example queries + fmt.Println("💡 Example Queries:") + fmt.Println(" To search for tools semantically:") + fmt.Println(" results, _ := toolsCol.Query(ctx, \"search repositories on GitHub\", 5, nil, nil)") + fmt.Println() + fmt.Println(" To filter by server:") + fmt.Println(" results, _ := toolsCol.Query(ctx, \"list files\", 5, map[string]string{\"server_id\": \"github\"}, nil)") +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func getKeys(m map[string]string) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} diff --git a/scripts/inspect-optimizer-db.sh b/scripts/inspect-optimizer-db.sh new file mode 100755 index 0000000000..b8d5ad8168 --- /dev/null +++ b/scripts/inspect-optimizer-db.sh @@ -0,0 +1,63 @@ +#!/bin/bash +# Inspect the optimizer SQLite FTS5 database + +set -e + +DB_PATH="${1:-/tmp/vmcp-optimizer-fts.db}" + +if [ ! -f "$DB_PATH" ]; then + echo "Error: Database not found at $DB_PATH" + echo "Usage: $0 [path-to-db]" + exit 1 +fi + +echo "📊 Optimizer FTS5 Database: $DB_PATH" +echo "" + +echo "📈 Statistics:" +sqlite3 "$DB_PATH" < [tool-name]") + fmt.Println("Example: go run view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents") + os.Exit(1) + } + + dbPath := os.Args[1] + searchTerm := "" + if len(os.Args) > 2 { + searchTerm = os.Args[2] + } + + // Read all collections + entries, err := os.ReadDir(dbPath) + if err != nil { + fmt.Printf("Error: %v\n", err) + os.Exit(1) + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + collectionPath := filepath.Join(dbPath, entry.Name()) + gobFiles, err := filepath.Glob(filepath.Join(collectionPath, "*.gob")) + if err != nil { + continue + } + + for _, gobFile := range gobFiles { + doc, err := decodeGobFile(gobFile) + if err != nil { + continue + } + + // Skip empty documents + if doc.ID == "" { + continue + } + + // If searching, filter by content + if searchTerm != "" && !contains(doc.Content, searchTerm) && !contains(doc.ID, searchTerm) { + continue + } + + fmt.Println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + fmt.Printf("Document ID: %s\n", doc.ID) + fmt.Printf("Content: %s\n", doc.Content) + fmt.Printf("Embedding Dimensions: %d\n", len(doc.Embedding)) + + // Show metadata + fmt.Println("\nMetadata:") + for key, value := range doc.Metadata { + if key == "data" { + // Pretty print JSON + var jsonData interface{} + if err := json.Unmarshal([]byte(value), &jsonData); err == nil { + prettyJSON, _ := json.MarshalIndent(jsonData, " ", " ") + fmt.Printf(" %s: %s\n", key, string(prettyJSON)) + } else { + fmt.Printf(" %s: %s\n", key, truncate(value, 200)) + } + } else { + fmt.Printf(" %s: %s\n", key, value) + } + } + + // Show first few embedding values + if len(doc.Embedding) > 0 { + fmt.Printf("\nEmbedding (first 10): [") + for i := 0; i < min(10, len(doc.Embedding)); i++ { + if i > 0 { + fmt.Print(", ") + } + fmt.Printf("%.3f", doc.Embedding[i]) + } + fmt.Println(", ...]") + } + fmt.Println() + } + } +} + +func decodeGobFile(path string) (*Document, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + dec := gob.NewDecoder(f) + var doc Document + if err := dec.Decode(&doc); err != nil { + return nil, err + } + + return &doc, nil +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && + (s == substr || + len(s) > len(substr) && + (s[:len(substr)] == substr || + s[len(s)-len(substr):] == substr || + findSubstring(s, substr))) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} From 04ebcf162747652064a12e17e58636b11a5aba63 Mon Sep 17 00:00:00 2001 From: Nigel Brown Date: Mon, 19 Jan 2026 10:32:23 +0000 Subject: [PATCH 03/16] feat: Add optimizer integration endpoints and tool discovery (#3318) * feat: Add optimizer integration endpoints and tool discovery - Add find_tool and call_tool endpoints to vmcp optimizer - Add semantic search and string matching for tool discovery - Update optimizer integration documentation - Add test scripts for optimizer functionality --- codecov.yaml | 2 + examples/vmcp-config-optimizer.yaml | 10 +- pkg/optimizer/INTEGRATION.md | 5 +- pkg/optimizer/README.md | 10 +- pkg/optimizer/db/backend_server.go | 8 +- .../db/backend_server_test_coverage.go | 94 ++ pkg/optimizer/db/backend_tool.go | 8 +- pkg/optimizer/db/backend_tool_test.go | 24 +- .../db/backend_tool_test_coverage.go | 96 ++ pkg/optimizer/db/db.go | 34 +- pkg/optimizer/db/db_test.go | 302 +++++ pkg/optimizer/db/fts.go | 16 + pkg/optimizer/db/fts_test_coverage.go | 159 +++ pkg/optimizer/doc.go | 4 +- pkg/optimizer/embeddings/manager.go | 111 +- .../embeddings/manager_test_coverage.go | 155 +++ pkg/optimizer/embeddings/ollama.go | 12 +- pkg/optimizer/embeddings/ollama_test.go | 76 +- .../embeddings/openai_compatible_test.go | 26 +- pkg/optimizer/ingestion/service.go | 35 +- pkg/optimizer/ingestion/service_test.go | 32 +- .../ingestion/service_test_coverage.go | 282 +++++ pkg/vmcp/client/client.go | 2 - pkg/vmcp/discovery/manager.go | 41 +- pkg/vmcp/discovery/manager_test_coverage.go | 173 +++ pkg/vmcp/health/checker.go | 92 +- pkg/vmcp/health/checker_selfcheck_test.go | 501 ++++++++ pkg/vmcp/health/checker_test.go | 14 +- pkg/vmcp/health/monitor.go | 6 +- pkg/vmcp/health/monitor_test.go | 20 +- .../find_tool_semantic_search_test.go | 690 +++++++++++ .../find_tool_string_matching_test.go | 696 +++++++++++ pkg/vmcp/optimizer/optimizer.go | 385 +++++-- pkg/vmcp/optimizer/optimizer_handlers_test.go | 1026 +++++++++++++++++ .../optimizer/optimizer_integration_test.go | 25 +- pkg/vmcp/optimizer/optimizer_unit_test.go | 103 +- pkg/vmcp/server/optimizer_test.go | 350 ++++++ pkg/vmcp/server/server.go | 153 ++- scripts/README.md | 35 +- scripts/call-optim-find-tool/main.go | 137 +++ scripts/inspect-chromem/inspect-chromem.go | 4 +- scripts/test-optim-find-tool/main.go | 246 ++++ scripts/test-vmcp-find-tool/main.go | 158 +++ 43 files changed, 5944 insertions(+), 414 deletions(-) create mode 100644 pkg/optimizer/db/backend_server_test_coverage.go create mode 100644 pkg/optimizer/db/backend_tool_test_coverage.go create mode 100644 pkg/optimizer/db/db_test.go create mode 100644 pkg/optimizer/db/fts_test_coverage.go create mode 100644 pkg/optimizer/embeddings/manager_test_coverage.go create mode 100644 pkg/optimizer/ingestion/service_test_coverage.go create mode 100644 pkg/vmcp/discovery/manager_test_coverage.go create mode 100644 pkg/vmcp/health/checker_selfcheck_test.go create mode 100644 pkg/vmcp/optimizer/find_tool_semantic_search_test.go create mode 100644 pkg/vmcp/optimizer/find_tool_string_matching_test.go create mode 100644 pkg/vmcp/optimizer/optimizer_handlers_test.go create mode 100644 pkg/vmcp/server/optimizer_test.go create mode 100644 scripts/call-optim-find-tool/main.go create mode 100644 scripts/test-optim-find-tool/main.go create mode 100644 scripts/test-vmcp-find-tool/main.go diff --git a/codecov.yaml b/codecov.yaml index 1a8032e484..410f9ae7ee 100644 --- a/codecov.yaml +++ b/codecov.yaml @@ -13,6 +13,8 @@ coverage: - "**/mocks/**/*" - "**/mock_*.go" - "**/zz_generated.deepcopy.go" + - "**/*_test.go" + - "**/*_test_coverage.go" status: project: default: diff --git a/examples/vmcp-config-optimizer.yaml b/examples/vmcp-config-optimizer.yaml index 5b20b074d9..7687dabb7d 100644 --- a/examples/vmcp-config-optimizer.yaml +++ b/examples/vmcp-config-optimizer.yaml @@ -45,11 +45,11 @@ optimizer: # Enable the optimizer enabled: true - # Embedding backend: "ollama", "openai-compatible", or "placeholder" - # - "ollama": Uses local Ollama HTTP API for embeddings + # Embedding backend: "ollama" (default), "openai-compatible", or "vllm" + # - "ollama": Uses local Ollama HTTP API for embeddings (default, requires 'ollama serve') # - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) - # - "placeholder": Uses deterministic hash-based embeddings (for testing) - embeddingBackend: placeholder + # - "vllm": Alias for OpenAI-compatible API + embeddingBackend: ollama # Embedding dimension (common values: 384, 768, 1536) # 384 is standard for all-MiniLM-L6-v2 and nomic-embed-text @@ -75,7 +75,7 @@ optimizer: # Option 1: Local Ollama (good for development/testing) # embeddingBackend: ollama # embeddingURL: http://localhost:11434 - # embeddingModel: nomic-embed-text + # embeddingModel: all-minilm # Default model (all-MiniLM-L6-v2) # embeddingDimension: 384 # Option 2: vLLM (recommended for production with GPU acceleration) diff --git a/pkg/optimizer/INTEGRATION.md b/pkg/optimizer/INTEGRATION.md index 4d2db78b59..e1cbd4d2df 100644 --- a/pkg/optimizer/INTEGRATION.md +++ b/pkg/optimizer/INTEGRATION.md @@ -93,7 +93,10 @@ func TestOptimizerIntegration(t *testing.T) { optimizerSvc, err := ingestion.NewService(&ingestion.Config{ DBConfig: &db.Config{Path: "/tmp/test-optimizer.db"}, EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, Dimension: 384, }, }) diff --git a/pkg/optimizer/README.md b/pkg/optimizer/README.md index 2984f2697a..f1a14938aa 100644 --- a/pkg/optimizer/README.md +++ b/pkg/optimizer/README.md @@ -132,9 +132,11 @@ func main() { panic(err) } - // Initialize embedding manager with placeholder (no external dependencies) + // Initialize embedding manager with Ollama (default) embeddingMgr, err := embeddings.NewManager(&embeddings.Config{ - BackendType: "placeholder", + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", Dimension: 384, }) if err != nil { @@ -201,7 +203,7 @@ spec: ollama serve # Pull an embedding model -ollama pull nomic-embed-text +ollama pull all-minilm ``` Configure vMCP: @@ -211,7 +213,7 @@ optimizer: enabled: true embeddingBackend: ollama embeddingURL: http://localhost:11434 - embeddingModel: nomic-embed-text + embeddingModel: all-minilm embeddingDimension: 384 ``` diff --git a/pkg/optimizer/db/backend_server.go b/pkg/optimizer/db/backend_server.go index 8685d4c47d..84ae5a3742 100644 --- a/pkg/optimizer/db/backend_server.go +++ b/pkg/optimizer/db/backend_server.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/philippgille/chromem-go" @@ -64,8 +65,13 @@ func (ops *BackendServerOps) Create(ctx context.Context, server *models.BackendS } // Also add to FTS5 database if available (for keyword filtering) + // Use background context to avoid cancellation issues - FTS5 is supplementary if ftsDB := ops.db.GetFTSDB(); ftsDB != nil { - if err := ftsDB.UpsertServer(ctx, server); err != nil { + // Use background context with timeout for FTS operations + // This ensures FTS operations complete even if the original context is canceled + ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := ftsDB.UpsertServer(ftsCtx, server); err != nil { // Log but don't fail - FTS5 is supplementary logger.Warnf("Failed to upsert server to FTS5: %v", err) } diff --git a/pkg/optimizer/db/backend_server_test_coverage.go b/pkg/optimizer/db/backend_server_test_coverage.go new file mode 100644 index 0000000000..411be12673 --- /dev/null +++ b/pkg/optimizer/db/backend_server_test_coverage.go @@ -0,0 +1,94 @@ +package db + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// TestBackendServerOps_Create_FTS tests FTS integration in Create +func TestBackendServerOps_Create_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendServerOps(db, embeddingFunc) + + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: stringPtr("A test server"), + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create should also update FTS + err = ops.Create(ctx, server) + require.NoError(t, err) + + // Verify FTS was updated by checking FTS DB directly + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) + + // FTS should have the server + // We can't easily query FTS directly, but we can verify it doesn't error +} + +// TestBackendServerOps_Delete_FTS tests FTS integration in Delete +func TestBackendServerOps_Delete_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendServerOps(db, embeddingFunc) + + desc := "A test server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &desc, + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create server + err = ops.Create(ctx, server) + require.NoError(t, err) + + // Delete should also delete from FTS + err = ops.Delete(ctx, server.ID) + require.NoError(t, err) +} diff --git a/pkg/optimizer/db/backend_tool.go b/pkg/optimizer/db/backend_tool.go index 909779edb8..3197428663 100644 --- a/pkg/optimizer/db/backend_tool.go +++ b/pkg/optimizer/db/backend_tool.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/philippgille/chromem-go" @@ -63,8 +64,13 @@ func (ops *BackendToolOps) Create(ctx context.Context, tool *models.BackendTool, } // Also add to FTS5 database if available (for BM25 search) + // Use background context to avoid cancellation issues - FTS5 is supplementary if ops.db.fts != nil { - if err := ops.db.fts.UpsertToolMeta(ctx, tool, serverName); err != nil { + // Use background context with timeout for FTS operations + // This ensures FTS operations complete even if the original context is canceled + ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := ops.db.fts.UpsertToolMeta(ftsCtx, tool, serverName); err != nil { // Log but don't fail - FTS5 is supplementary logger.Warnf("Failed to upsert tool to FTS5: %v", err) } diff --git a/pkg/optimizer/db/backend_tool_test.go b/pkg/optimizer/db/backend_tool_test.go index 557e5ca5f5..95d2d5330b 100644 --- a/pkg/optimizer/db/backend_tool_test.go +++ b/pkg/optimizer/db/backend_tool_test.go @@ -12,7 +12,7 @@ import ( "github.com/stacklok/toolhive/pkg/optimizer/models" ) -// createTestDB creates a test database with placeholder embeddings +// createTestDB creates a test database func createTestDB(t *testing.T) *DB { t.Helper() tmpDir := t.TempDir() @@ -27,18 +27,23 @@ func createTestDB(t *testing.T) *DB { return db } -// createTestEmbeddingFunc creates a test embedding function using placeholder embeddings +// createTestEmbeddingFunc creates a test embedding function using Ollama embeddings func createTestEmbeddingFunc(t *testing.T) func(ctx context.Context, text string) ([]float32, error) { t.Helper() - // Create placeholder embedding manager + // Try to use Ollama if available, otherwise skip test config := &embeddings.Config{ - BackendType: "placeholder", + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", Dimension: 384, } manager, err := embeddings.NewManager(config) - require.NoError(t, err) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return nil + } t.Cleanup(func() { _ = manager.Close() }) return func(_ context.Context, text string) ([]float32, error) { @@ -454,9 +459,12 @@ func TestBackendToolOps_Search(t *testing.T) { require.NoError(t, err) assert.NotEmpty(t, results, "Should find tools") - // With placeholder embeddings, we just verify we get results - // Semantic similarity isn't guaranteed with hash-based embeddings - assert.Len(t, results, 2, "Should return both tools") + // Weather tool should be most similar to weather query + assert.NotEmpty(t, results, "Should find at least one tool") + if len(results) > 0 { + assert.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") + } } // TestBackendToolOps_Search_WithServerFilter tests search with server ID filter diff --git a/pkg/optimizer/db/backend_tool_test_coverage.go b/pkg/optimizer/db/backend_tool_test_coverage.go new file mode 100644 index 0000000000..a8766c302b --- /dev/null +++ b/pkg/optimizer/db/backend_tool_test_coverage.go @@ -0,0 +1,96 @@ +package db + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// TestBackendToolOps_Create_FTS tests FTS integration in Create +func TestBackendToolOps_Create_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendToolOps(db, embeddingFunc) + + desc := "A test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &desc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 10, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create should also update FTS + err = ops.Create(ctx, tool, "TestServer") + require.NoError(t, err) + + // Verify FTS was updated + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) +} + +// TestBackendToolOps_DeleteByServer_FTS tests FTS integration in DeleteByServer +func TestBackendToolOps_DeleteByServer_FTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + ops := NewBackendToolOps(db, embeddingFunc) + + desc := "A test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &desc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 10, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create tool + err = ops.Create(ctx, tool, "TestServer") + require.NoError(t, err) + + // DeleteByServer should also delete from FTS + err = ops.DeleteByServer(ctx, "server-1") + require.NoError(t, err) +} diff --git a/pkg/optimizer/db/db.go b/pkg/optimizer/db/db.go index f7e7df5bb8..2e1b88a24f 100644 --- a/pkg/optimizer/db/db.go +++ b/pkg/optimizer/db/db.go @@ -3,6 +3,8 @@ package db import ( "context" "fmt" + "os" + "strings" "sync" "github.com/philippgille/chromem-go" @@ -54,7 +56,35 @@ func NewDB(config *Config) (*DB, error) { logger.Infof("Creating chromem-go database with persistence at: %s", config.PersistPath) chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false) if err != nil { - return nil, fmt.Errorf("failed to create persistent database: %w", err) + // Check if error is due to corrupted database (missing collection metadata) + if strings.Contains(err.Error(), "collection metadata file not found") { + logger.Warnf("Database appears corrupted, attempting to remove and recreate: %v", err) + // Try to remove corrupted database directory + // Use RemoveAll which should handle directories recursively + // If it fails, we'll try to create with a new path or fall back to in-memory + if removeErr := os.RemoveAll(config.PersistPath); removeErr != nil { + logger.Warnf("Failed to remove corrupted database directory (may be in use): %v. Will try to recreate anyway.", removeErr) + // Try to rename the corrupted directory and create a new one + backupPath := config.PersistPath + ".corrupted" + if renameErr := os.Rename(config.PersistPath, backupPath); renameErr != nil { + logger.Warnf("Failed to rename corrupted database: %v. Attempting to create database anyway.", renameErr) + // Continue and let chromem-go handle it - it might work if the corruption is partial + } else { + logger.Infof("Renamed corrupted database to: %s", backupPath) + } + } + // Retry creating the database + chromemDB, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + // If still failing, return the error but suggest manual cleanup + return nil, fmt.Errorf( + "failed to create persistent database after cleanup attempt. Please manually remove %s and try again: %w", + config.PersistPath, err) + } + logger.Info("Successfully recreated database after cleanup") + } else { + return nil, fmt.Errorf("failed to create persistent database: %w", err) + } } } else { logger.Info("Creating in-memory chromem-go database") @@ -160,7 +190,7 @@ func (db *DB) GetFTSDB() *FTSDatabase { return db.fts } -// Reset clears all collections and FTS tables (useful for testing) +// Reset clears all collections and FTS tables (useful for testing and startup) func (db *DB) Reset() { db.mu.Lock() defer db.mu.Unlock() diff --git a/pkg/optimizer/db/db_test.go b/pkg/optimizer/db/db_test.go new file mode 100644 index 0000000000..2da34c214a --- /dev/null +++ b/pkg/optimizer/db/db_test.go @@ -0,0 +1,302 @@ +package db + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewDB_CorruptedDatabase tests database recovery from corruption +func TestNewDB_CorruptedDatabase(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + config := &Config{ + PersistPath: dbPath, + } + + // Should recover from corruption + db, err := NewDB(config) + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.Close() }() +} + +// TestNewDB_CorruptedDatabase_RecoveryFailure tests when recovery fails +func TestNewDB_CorruptedDatabase_RecoveryFailure(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + // Make directory read-only to simulate recovery failure + // Note: This might not work on all systems, so we'll test the error path differently + // Instead, we'll test with an invalid path that can't be created + config := &Config{ + PersistPath: "/invalid/path/that/does/not/exist", + } + + _, err = NewDB(config) + // Should return error for invalid path + assert.Error(t, err) +} + +// TestDB_GetOrCreateCollection tests collection creation and retrieval +func TestDB_GetOrCreateCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + // Create a simple embedding function + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get or create collection + collection, err := db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) + + // Get existing collection + collection2, err := db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection2) + assert.Equal(t, collection, collection2) +} + +// TestDB_GetCollection tests collection retrieval +func TestDB_GetCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get non-existent collection should fail + _, err = db.GetCollection("non-existent", embeddingFunc) + assert.Error(t, err) + + // Create collection first + _, err = db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Now get it + collection, err := db.GetCollection("test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) +} + +// TestDB_DeleteCollection tests collection deletion +func TestDB_DeleteCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collection + _, err = db.GetOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Delete collection + db.DeleteCollection("test-collection") + + // Verify it's deleted + _, err = db.GetCollection("test-collection", embeddingFunc) + assert.Error(t, err) +} + +// TestDB_Reset tests database reset +func TestDB_Reset(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collections + _, err = db.GetOrCreateCollection(ctx, BackendServerCollection, embeddingFunc) + require.NoError(t, err) + + _, err = db.GetOrCreateCollection(ctx, BackendToolCollection, embeddingFunc) + require.NoError(t, err) + + // Reset database + db.Reset() + + // Verify collections are deleted + _, err = db.GetCollection(BackendServerCollection, embeddingFunc) + assert.Error(t, err) + + _, err = db.GetCollection(BackendToolCollection, embeddingFunc) + assert.Error(t, err) +} + +// TestDB_GetChromemDB tests chromem DB accessor +func TestDB_GetChromemDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + chromemDB := db.GetChromemDB() + require.NotNil(t, chromemDB) +} + +// TestDB_GetFTSDB tests FTS DB accessor +func TestDB_GetFTSDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + defer func() { _ = db.Close() }() + + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) +} + +// TestDB_Close tests database closing +func TestDB_Close(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := NewDB(config) + require.NoError(t, err) + + err = db.Close() + require.NoError(t, err) + + // Multiple closes should be safe + err = db.Close() + require.NoError(t, err) +} + +// TestNewDB_FTSDBPath tests FTS database path configuration +func TestNewDB_FTSDBPath(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "in-memory FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db"), + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + { + name: "persistent FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db2"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + }, + wantErr: false, + }, + { + name: "default FTS path with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db3"), + // FTSDBPath not set, should default to {PersistPath}/fts.db + }, + wantErr: false, + }, + { + name: "in-memory FTS with in-memory chromem", + config: &Config{ + PersistPath: "", + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db, err := NewDB(tt.config) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.Close() }() + + // Verify FTS DB is accessible + ftsDB := db.GetFTSDB() + require.NotNil(t, ftsDB) + } + }) + } +} diff --git a/pkg/optimizer/db/fts.go b/pkg/optimizer/db/fts.go index 8dde0b2aa3..e9cecd7a09 100644 --- a/pkg/optimizer/db/fts.go +++ b/pkg/optimizer/db/fts.go @@ -316,6 +316,22 @@ func (fts *FTSDatabase) SearchBM25( return results, nil } +// GetTotalToolTokens returns the sum of token_count across all tools +func (fts *FTSDatabase) GetTotalToolTokens(ctx context.Context) (int, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + var totalTokens int + query := "SELECT COALESCE(SUM(token_count), 0) FROM backend_tools_fts" + + err := fts.db.QueryRowContext(ctx, query).Scan(&totalTokens) + if err != nil { + return 0, fmt.Errorf("failed to get total tool tokens: %w", err) + } + + return totalTokens, nil +} + // Close closes the FTS database connection func (fts *FTSDatabase) Close() error { return fts.db.Close() diff --git a/pkg/optimizer/db/fts_test_coverage.go b/pkg/optimizer/db/fts_test_coverage.go new file mode 100644 index 0000000000..b6a7fe2321 --- /dev/null +++ b/pkg/optimizer/db/fts_test_coverage.go @@ -0,0 +1,159 @@ +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/models" +) + +// stringPtr returns a pointer to the given string +func stringPtr(s string) *string { + return &s +} + +// TestFTSDatabase_GetTotalToolTokens tests token counting +func TestFTSDatabase_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Initially should be 0 + totalTokens, err := ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 0, totalTokens) + + // Add a tool + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: stringPtr("Test tool"), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool, "TestServer") + require.NoError(t, err) + + // Should now have tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 100, totalTokens) + + // Add another tool + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "test_tool2", + Description: stringPtr("Test tool 2"), + TokenCount: 50, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool2, "TestServer") + require.NoError(t, err) + + // Should sum tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 150, totalTokens) +} + +// TestSanitizeFTS5Query tests query sanitization +func TestSanitizeFTS5Query(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "remove quotes", + input: `"test query"`, + expected: "test query", + }, + { + name: "remove wildcards", + input: "test*query", + expected: "test query", + }, + { + name: "remove parentheses", + input: "test(query)", + expected: "test query", + }, + { + name: "remove multiple spaces", + input: "test query", + expected: "test query", + }, + { + name: "trim whitespace", + input: " test query ", + expected: "test query", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "only special characters", + input: `"*()`, + expected: "", + }, + { + name: "mixed special characters", + input: `test"query*with(special)chars`, + expected: "test query with special chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := sanitizeFTS5Query(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestFTSDatabase_SearchBM25_EmptyQuery tests empty query handling +func TestFTSDatabase_SearchBM25_EmptyQuery(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Empty query should return empty results + results, err := ftsDB.SearchBM25(ctx, "", 10, nil) + require.NoError(t, err) + assert.Empty(t, results) + + // Query with only special characters should return empty results + results, err = ftsDB.SearchBM25(ctx, `"*()`, 10, nil) + require.NoError(t, err) + assert.Empty(t, results) +} diff --git a/pkg/optimizer/doc.go b/pkg/optimizer/doc.go index 0808bb76b2..549bf23900 100644 --- a/pkg/optimizer/doc.go +++ b/pkg/optimizer/doc.go @@ -69,7 +69,9 @@ // // // Create embedding manager // embMgr, err := embeddings.NewManager(embeddings.Config{ -// BackendType: "placeholder", // or "ollama" or "openai-compatible" +// BackendType: "ollama", // or "openai-compatible" or "vllm" +// BaseURL: "http://localhost:11434", +// Model: "all-minilm", // Dimension: 384, // }) // diff --git a/pkg/optimizer/embeddings/manager.go b/pkg/optimizer/embeddings/manager.go index 9ccc94fca3..70ac838492 100644 --- a/pkg/optimizer/embeddings/manager.go +++ b/pkg/optimizer/embeddings/manager.go @@ -8,17 +8,19 @@ import ( ) const ( - // BackendTypePlaceholder is the placeholder backend type - BackendTypePlaceholder = "placeholder" + // DefaultModelAllMiniLM is the default Ollama model name + DefaultModelAllMiniLM = "all-minilm" + // BackendTypeOllama is the Ollama backend type + BackendTypeOllama = "ollama" ) // Config holds configuration for the embedding manager type Config struct { // BackendType specifies which backend to use: - // - "ollama": Ollama native API + // - "ollama": Ollama native API (default) // - "vllm": vLLM OpenAI-compatible API // - "unified": Generic OpenAI-compatible API (works with both) - // - "placeholder": Hash-based embeddings for testing + // - "openai": OpenAI-compatible API BackendType string // BaseURL is the base URL for the embedding service @@ -27,7 +29,7 @@ type Config struct { BaseURL string // Model is the model name to use - // - Ollama: "nomic-embed-text", "all-minilm" + // - Ollama: "all-minilm" (default), "nomic-embed-text" // - vLLM: "sentence-transformers/all-MiniLM-L6-v2", "intfloat/e5-mistral-7b-instruct" Model string @@ -68,9 +70,9 @@ func NewManager(config *Config) (*Manager, error) { config.MaxCacheSize = 1000 } - // Default to placeholder (zero dependencies) + // Default to Ollama if config.BackendType == "" { - config.BackendType = "placeholder" + config.BackendType = BackendTypeOllama } // Initialize backend based on configuration @@ -78,7 +80,7 @@ func NewManager(config *Config) (*Manager, error) { var err error switch config.BackendType { - case "ollama": + case BackendTypeOllama: // Use Ollama native API (requires ollama serve) baseURL := config.BaseURL if baseURL == "" { @@ -86,13 +88,17 @@ func NewManager(config *Config) (*Manager, error) { } model := config.Model if model == "" { - model = "nomic-embed-text" + model = DefaultModelAllMiniLM // Default: all-MiniLM-L6-v2 + } + // Update dimension if not set and using default model + if config.Dimension == 0 && model == DefaultModelAllMiniLM { + config.Dimension = 384 } backend, err = NewOllamaBackend(baseURL, model) if err != nil { - logger.Warnf("Failed to initialize Ollama backend: %v", err) - logger.Info("Falling back to placeholder embeddings. To use Ollama: ollama serve && ollama pull nomic-embed-text") - backend = &PlaceholderBackend{dimension: config.Dimension} + return nil, fmt.Errorf( + "failed to initialize Ollama backend: %w (ensure 'ollama serve' is running and 'ollama pull %s' has been executed)", + err, DefaultModelAllMiniLM) } case "vllm", "unified", "openai": @@ -107,17 +113,11 @@ func NewManager(config *Config) (*Manager, error) { } backend, err = NewOpenAICompatibleBackend(config.BaseURL, config.Model, config.Dimension) if err != nil { - logger.Warnf("Failed to initialize %s backend: %v", config.BackendType, err) - logger.Infof("Falling back to placeholder embeddings") - backend = &PlaceholderBackend{dimension: config.Dimension} + return nil, fmt.Errorf("failed to initialize %s backend: %w", config.BackendType, err) } - case BackendTypePlaceholder: - // Use placeholder for testing - backend = &PlaceholderBackend{dimension: config.Dimension} - default: - return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, placeholder)", config.BackendType) + return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, openai)", config.BackendType) } m := &Manager{ @@ -154,17 +154,7 @@ func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) { // Use backend to generate embeddings embeddings, err := m.backend.EmbedBatch(texts) if err != nil { - // If backend fails, fall back to placeholder for non-placeholder backends - if m.config.BackendType != "placeholder" { - logger.Warnf("%s backend failed: %v, falling back to placeholder", m.config.BackendType, err) - placeholder := &PlaceholderBackend{dimension: m.config.Dimension} - embeddings, err = placeholder.EmbedBatch(texts) - if err != nil { - return nil, fmt.Errorf("failed to generate embeddings: %w", err) - } - } else { - return nil, fmt.Errorf("failed to generate embeddings: %w", err) - } + return nil, fmt.Errorf("failed to generate embeddings: %w", err) } // Cache single embeddings @@ -176,65 +166,6 @@ func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) { return embeddings, nil } -// PlaceholderBackend is a simple backend for testing -type PlaceholderBackend struct { - dimension int -} - -// Embed generates a deterministic hash-based embedding for the given text. -func (p *PlaceholderBackend) Embed(text string) ([]float32, error) { - return p.generatePlaceholderEmbedding(text), nil -} - -// EmbedBatch generates embeddings for multiple texts. -func (p *PlaceholderBackend) EmbedBatch(texts []string) ([][]float32, error) { - embeddings := make([][]float32, len(texts)) - for i, text := range texts { - embeddings[i] = p.generatePlaceholderEmbedding(text) - } - return embeddings, nil -} - -// Dimension returns the embedding dimension. -func (p *PlaceholderBackend) Dimension() int { - return p.dimension -} - -// Close closes the backend (no-op for placeholder). -func (*PlaceholderBackend) Close() error { - return nil -} - -func (p *PlaceholderBackend) generatePlaceholderEmbedding(text string) []float32 { - embedding := make([]float32, p.dimension) - - // Simple hash-based generation for testing - hash := 0 - for _, c := range text { - hash = (hash*31 + int(c)) % 1000000 - } - - // Generate deterministic values - for i := range embedding { - hash = (hash*1103515245 + 12345) % 1000000 - embedding[i] = float32(hash) / 1000000.0 - } - - // Normalize the embedding (L2 normalization) - var norm float32 - for _, v := range embedding { - norm += v * v - } - if norm > 0 { - norm = float32(1.0 / float64(norm)) - for i := range embedding { - embedding[i] *= norm - } - } - - return embedding -} - // GetCacheStats returns cache statistics func (m *Manager) GetCacheStats() map[string]interface{} { if !m.config.EnableCache || m.cache == nil { diff --git a/pkg/optimizer/embeddings/manager_test_coverage.go b/pkg/optimizer/embeddings/manager_test_coverage.go new file mode 100644 index 0000000000..98eb4a9eec --- /dev/null +++ b/pkg/optimizer/embeddings/manager_test_coverage.go @@ -0,0 +1,155 @@ +package embeddings + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestManager_GetCacheStats tests cache statistics +func TestManager_GetCacheStats(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.True(t, stats["enabled"].(bool)) + assert.Contains(t, stats, "hits") + assert.Contains(t, stats, "misses") + assert.Contains(t, stats, "size") + assert.Contains(t, stats, "maxsize") +} + +// TestManager_GetCacheStats_Disabled tests cache statistics when cache is disabled +func TestManager_GetCacheStats_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.False(t, stats["enabled"].(bool)) +} + +// TestManager_ClearCache tests cache clearing +func TestManager_ClearCache(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic + manager.ClearCache() + + // Multiple clears should be safe + manager.ClearCache() +} + +// TestManager_ClearCache_Disabled tests cache clearing when cache is disabled +func TestManager_ClearCache_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic even when disabled + manager.ClearCache() +} + +// TestManager_Dimension tests dimension accessor +func TestManager_Dimension(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} + +// TestManager_Dimension_Default tests default dimension +func TestManager_Dimension_Default(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + // Dimension not set, should default to 384 + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} diff --git a/pkg/optimizer/embeddings/ollama.go b/pkg/optimizer/embeddings/ollama.go index d6f4874375..a05af2af11 100644 --- a/pkg/optimizer/embeddings/ollama.go +++ b/pkg/optimizer/embeddings/ollama.go @@ -31,21 +31,27 @@ type ollamaEmbedResponse struct { // NewOllamaBackend creates a new Ollama backend // Requires Ollama to be running locally: ollama serve -// Default model: nomic-embed-text (768 dimensions) +// Default model: all-minilm (all-MiniLM-L6-v2, 384 dimensions) func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) { if baseURL == "" { baseURL = "http://localhost:11434" } if model == "" { - model = "nomic-embed-text" // Default embedding model + model = "all-minilm" // Default embedding model (all-MiniLM-L6-v2) } logger.Infof("Initializing Ollama backend (model: %s, url: %s)", model, baseURL) + // Determine dimension based on model + dimension := 384 // Default for all-minilm + if model == "nomic-embed-text" { + dimension = 768 + } + backend := &OllamaBackend{ baseURL: baseURL, model: model, - dimension: 768, // nomic-embed-text dimension + dimension: dimension, client: &http.Client{}, } diff --git a/pkg/optimizer/embeddings/ollama_test.go b/pkg/optimizer/embeddings/ollama_test.go index 5254b7c072..83594863e5 100644 --- a/pkg/optimizer/embeddings/ollama_test.go +++ b/pkg/optimizer/embeddings/ollama_test.go @@ -4,13 +4,12 @@ import ( "testing" ) -func TestOllamaBackend_Placeholder(t *testing.T) { +func TestOllamaBackend_ConnectionFailure(t *testing.T) { t.Parallel() - // This test verifies that Ollama backend is properly structured - // Actual Ollama tests require ollama to be running + // This test verifies that Ollama backend handles connection failures gracefully // Test that NewOllamaBackend handles connection failure gracefully - _, err := NewOllamaBackend("http://localhost:99999", "nomic-embed-text") + _, err := NewOllamaBackend("http://localhost:99999", "all-minilm") if err == nil { t.Error("Expected error when connecting to invalid Ollama URL") } @@ -18,68 +17,36 @@ func TestOllamaBackend_Placeholder(t *testing.T) { func TestManagerWithOllama(t *testing.T) { t.Parallel() - // Test that Manager falls back to placeholder when Ollama is not available or model not pulled + // Test that Manager works with Ollama when available config := &Config{ - BackendType: "ollama", - Dimension: 384, + BackendType: BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: DefaultModelAllMiniLM, + Dimension: 768, EnableCache: true, MaxCacheSize: 100, } manager, err := NewManager(config) if err != nil { - t.Fatalf("Failed to create manager: %v", err) - } - defer manager.Close() - - // Should work with placeholder backend fallback - // (Ollama might not have model pulled, so it falls back to placeholder) - embeddings, err := manager.GenerateEmbedding([]string{"test text"}) - - // If Ollama is available with the model, great! - // If not, it should have fallen back to placeholder - if err != nil { - // Check if it's a "model not found" error - this is expected - if embeddings == nil { - t.Skip("Ollama not available or model not pulled (expected in CI/test environments)") - } - } - - if len(embeddings) != 1 { - t.Errorf("Expected 1 embedding, got %d", len(embeddings)) - } - - // Dimension could be 384 (placeholder) or 768 (Ollama nomic-embed-text) - if len(embeddings[0]) != 384 && len(embeddings[0]) != 768 { - t.Errorf("Expected dimension 384 or 768, got %d", len(embeddings[0])) - } -} - -func TestManagerWithPlaceholder(t *testing.T) { - t.Parallel() - // Test explicit placeholder backend - config := &Config{ - BackendType: "placeholder", - Dimension: 384, - EnableCache: false, - } - - manager, err := NewManager(config) - if err != nil { - t.Fatalf("Failed to create manager: %v", err) + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return } defer manager.Close() // Test single embedding - embeddings, err := manager.GenerateEmbedding([]string{"hello world"}) + embeddings, err := manager.GenerateEmbedding([]string{"test text"}) if err != nil { - t.Fatalf("Failed to generate embedding: %v", err) + // Model might not be pulled - skip gracefully + t.Skipf("Skipping test: Failed to generate embedding. Error: %v. Run 'ollama pull nomic-embed-text'", err) + return } if len(embeddings) != 1 { t.Errorf("Expected 1 embedding, got %d", len(embeddings)) } + // Ollama all-minilm uses 384 dimensions if len(embeddings[0]) != 384 { t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) } @@ -88,19 +55,12 @@ func TestManagerWithPlaceholder(t *testing.T) { texts := []string{"text 1", "text 2", "text 3"} embeddings, err = manager.GenerateEmbedding(texts) if err != nil { - t.Fatalf("Failed to generate batch embeddings: %v", err) + // Model might not be pulled - skip gracefully + t.Skipf("Skipping test: Failed to generate batch embeddings. Error: %v. Run 'ollama pull nomic-embed-text'", err) + return } if len(embeddings) != 3 { t.Errorf("Expected 3 embeddings, got %d", len(embeddings)) } - - // Verify embeddings are deterministic - embeddings2, _ := manager.GenerateEmbedding([]string{"text 1"}) - for i := range embeddings[0] { - if embeddings[0][i] != embeddings2[0][i] { - t.Error("Embeddings should be deterministic") - break - } - } } diff --git a/pkg/optimizer/embeddings/openai_compatible_test.go b/pkg/optimizer/embeddings/openai_compatible_test.go index 916ad0cb8f..e829d2d6ac 100644 --- a/pkg/optimizer/embeddings/openai_compatible_test.go +++ b/pkg/optimizer/embeddings/openai_compatible_test.go @@ -206,30 +206,18 @@ func TestManagerWithUnified(t *testing.T) { func TestManagerFallbackBehavior(t *testing.T) { t.Parallel() - // Test that invalid vLLM backend falls back to placeholder + // Test that invalid vLLM backend fails gracefully during initialization + // (No fallback behavior is currently implemented) config := &Config{ BackendType: "vllm", - BaseURL: "http://invalid-host-that-does-not-exist:99999", + BaseURL: "http://invalid-host-that-does-not-exist:9999", Model: "test-model", Dimension: 384, } - manager, err := NewManager(config) - if err != nil { - t.Fatalf("Failed to create manager: %v", err) - } - defer manager.Close() - - // Should still work with placeholder fallback - embeddings, err := manager.GenerateEmbedding([]string{"test"}) - if err != nil { - t.Fatalf("Failed to generate embeddings with fallback: %v", err) - } - - if len(embeddings) != 1 { - t.Errorf("Expected 1 embedding, got %d", len(embeddings)) - } - if len(embeddings[0]) != 384 { - t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + _, err := NewManager(config) + if err == nil { + t.Error("Expected error when creating manager with invalid backend URL") } + // Test passes if error is returned (no fallback behavior) } diff --git a/pkg/optimizer/ingestion/service.go b/pkg/optimizer/ingestion/service.go index 821f970d6f..9b63e01289 100644 --- a/pkg/optimizer/ingestion/service.go +++ b/pkg/optimizer/ingestion/service.go @@ -65,6 +65,11 @@ func NewService(config *Config) (*Service, error) { return nil, fmt.Errorf("failed to initialize database: %w", err) } + // Clear database on startup to ensure fresh embeddings + // This is important when the embedding model changes or for consistency + database.Reset() + logger.Info("Cleared optimizer database on startup") + // Initialize embedding manager embeddingManager, err := embeddings.NewManager(config.EmbeddingConfig) if err != nil { @@ -124,7 +129,7 @@ func (s *Service) IngestServer( description *string, tools []mcp.Tool, ) error { - logger.Infof("Ingesting server: %s (%d tools)", serverName, len(tools)) + logger.Infof("Ingesting server: %s (%d tools) [serverID=%s]", serverName, len(tools), serverID) // Create backend server record (simplified - vMCP manages lifecycle) // chromem-go will generate embeddings automatically from the content @@ -155,6 +160,7 @@ func (s *Service) IngestServer( // syncBackendTools synchronizes tools for a backend server func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) { + logger.Debugf("syncBackendTools: server=%s, serverID=%s, tool_count=%d", serverName, serverID, len(tools)) // Delete existing tools if err := s.backendToolOps.DeleteByServer(ctx, serverID); err != nil { return 0, fmt.Errorf("failed to delete existing tools: %w", err) @@ -195,6 +201,33 @@ func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverN return len(tools), nil } +// GetEmbeddingManager returns the embedding manager for this service +func (s *Service) GetEmbeddingManager() *embeddings.Manager { + return s.embeddingManager +} + +// GetBackendToolOps returns the backend tool operations for search and retrieval +func (s *Service) GetBackendToolOps() *db.BackendToolOps { + return s.backendToolOps +} + +// GetTotalToolTokens returns the total token count across all tools in the database +func (s *Service) GetTotalToolTokens(ctx context.Context) int { + // Use FTS database to efficiently count all tool tokens + if s.database.GetFTSDB() != nil { + totalTokens, err := s.database.GetFTSDB().GetTotalToolTokens(ctx) + if err != nil { + logger.Warnw("Failed to get total tool tokens from FTS", "error", err) + return 0 + } + return totalTokens + } + + // Fallback: query all tools (less efficient but works) + logger.Warn("FTS database not available, using fallback for token counting") + return 0 +} + // Close releases resources func (s *Service) Close() error { var errs []error diff --git a/pkg/optimizer/ingestion/service_test.go b/pkg/optimizer/ingestion/service_test.go index 51c73767b8..acc5b18754 100644 --- a/pkg/optimizer/ingestion/service_test.go +++ b/pkg/optimizer/ingestion/service_test.go @@ -25,14 +25,31 @@ func TestServiceCreationAndIngestion(t *testing.T) { // Create temporary directory for persistence (optional) tmpDir := t.TempDir() - // Initialize service with placeholder embeddings (no dependencies) + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service with Ollama embeddings config := &Config{ DBConfig: &db.Config{ PersistPath: filepath.Join(tmpDir, "test-db"), }, EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", // Use placeholder for testing - Dimension: 384, + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, }, } @@ -78,11 +95,11 @@ func TestServiceCreationAndIngestion(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, results, "Should find at least one similar tool") - // With placeholder embeddings (hash-based), semantic similarity isn't guaranteed - // Just verify we got results back - require.Len(t, results, 2, "Should return both tools") + require.NotEmpty(t, results, "Should return at least one result") - // Verify both tools are present (order doesn't matter with placeholder embeddings) + // Weather tool should be most similar to weather query + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") toolNamesFound := make(map[string]bool) for _, result := range results { toolNamesFound[result.ToolName] = true @@ -142,7 +159,6 @@ func TestServiceWithOllama(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, results) - // With real embeddings, weather tool should be most similar require.Equal(t, "get_weather", results[0].ToolName, "Weather tool should be most similar to weather query") } diff --git a/pkg/optimizer/ingestion/service_test_coverage.go b/pkg/optimizer/ingestion/service_test_coverage.go new file mode 100644 index 0000000000..2328db7120 --- /dev/null +++ b/pkg/optimizer/ingestion/service_test_coverage.go @@ -0,0 +1,282 @@ +package ingestion + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/db" + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" +) + +// TestService_GetTotalToolTokens tests token counting +func TestService_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Ingest some tools + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + { + Name: "tool2", + Description: "Tool 2", + }, + } + + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Get total tokens + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetTotalToolTokens_NoFTS tests token counting without FTS +func TestService_GetTotalToolTokens_NoFTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: "", // In-memory + FTSDBPath: "", // Will default to :memory: + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Get total tokens (should use FTS if available, fallback otherwise) + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetBackendToolOps tests backend tool ops accessor +func TestService_GetBackendToolOps(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + toolOps := svc.GetBackendToolOps() + require.NotNil(t, toolOps) +} + +// TestService_GetEmbeddingManager tests embedding manager accessor +func TestService_GetEmbeddingManager(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + manager := svc.GetEmbeddingManager() + require.NotNil(t, manager) +} + +// TestService_IngestServer_ErrorHandling tests error handling during ingestion +func TestService_IngestServer_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Test with empty tools list + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, []mcp.Tool{}) + require.NoError(t, err, "Should handle empty tools list gracefully") + + // Test with nil description + err = svc.IngestServer(ctx, "server-2", "TestServer2", nil, []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + }) + require.NoError(t, err, "Should handle nil description gracefully") +} + +// TestService_Close_ErrorHandling tests error handling during close +func TestService_Close_ErrorHandling(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + + // Close should succeed + err = svc.Close() + require.NoError(t, err) + + // Multiple closes should be safe + err = svc.Close() + require.NoError(t, err) +} diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index 09d69a2fda..e99533a83a 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -123,8 +123,6 @@ func (a *authRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) return nil, fmt.Errorf("authentication failed for backend %s: %w", a.target.WorkloadID, err) } - logger.Debugf("Applied authentication strategy %q to backend %s", a.authStrategy.Name(), a.target.WorkloadID) - return a.base.RoundTrip(reqClone) } diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go index 9bdfdc1d39..86c6b82482 100644 --- a/pkg/vmcp/discovery/manager.go +++ b/pkg/vmcp/discovery/manager.go @@ -15,6 +15,8 @@ import ( "sync" "time" + "golang.org/x/sync/singleflight" + "github.com/stacklok/toolhive/pkg/auth" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/vmcp" @@ -65,6 +67,9 @@ type DefaultManager struct { stopCh chan struct{} stopOnce sync.Once wg sync.WaitGroup + // singleFlight ensures only one aggregation happens per cache key at a time + // This prevents concurrent requests from all triggering aggregation + singleFlight singleflight.Group } // NewManager creates a new discovery manager with the given aggregator. @@ -128,6 +133,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi // // The context must contain an authenticated user identity (set by auth middleware). // Returns ErrNoIdentity if user identity is not found in context. +// +// This method uses singleflight to ensure that concurrent requests for the same +// cache key only trigger one aggregation, preventing duplicate work. func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) { // Validate user identity is present (set by auth middleware) // This ensures discovery happens with proper user authentication context @@ -139,7 +147,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) // Generate cache key from user identity and backend set cacheKey := m.generateCacheKey(identity.Subject, backends) - // Check cache first + // Check cache first (with read lock) if caps := m.getCachedCapabilities(cacheKey); caps != nil { logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey) return caps, nil @@ -147,16 +155,33 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject) - // Cache miss - perform aggregation - caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + // Use singleflight to ensure only one aggregation happens per cache key + // Even if multiple requests come in concurrently, they'll all wait for the same result + result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) { + // Double-check cache after acquiring singleflight lock + // Another goroutine might have populated it while we were waiting + if caps := m.getCachedCapabilities(cacheKey); caps != nil { + logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject) + return caps, nil + } + + // Perform aggregation + caps, err := m.aggregator.AggregateCapabilities(ctx, backends) + if err != nil { + return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + } + + // Cache the result (skips caching if at capacity and key doesn't exist) + m.cacheCapabilities(cacheKey, caps) + + return caps, nil + }) + if err != nil { - return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err) + return nil, err } - // Cache the result (skips caching if at capacity and key doesn't exist) - m.cacheCapabilities(cacheKey, caps) - - return caps, nil + return result.(*aggregator.AggregatedCapabilities), nil } // Stop gracefully stops the manager and cleans up resources. diff --git a/pkg/vmcp/discovery/manager_test_coverage.go b/pkg/vmcp/discovery/manager_test_coverage.go new file mode 100644 index 0000000000..2d31a9db56 --- /dev/null +++ b/pkg/vmcp/discovery/manager_test_coverage.go @@ -0,0 +1,173 @@ +package discovery + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/auth" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + aggmocks "github.com/stacklok/toolhive/pkg/vmcp/aggregator/mocks" + vmcpmocks "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +// TestDefaultManager_CacheVersionMismatch tests cache invalidation on version mismatch +func TestDefaultManager_CacheVersionMismatch(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + mockRegistry := vmcpmocks.NewMockDynamicRegistry(ctrl) + + // First call - version 1 + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + manager, err := NewManagerWithRegistry(mockAggregator, mockRegistry) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery - should cache + caps1, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps1) + + // Second discovery with same version - should use cache + mockRegistry.EXPECT().Version().Return(uint64(1)).Times(1) + caps2, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps2) + + // Third discovery with different version - should invalidate cache + mockRegistry.EXPECT().Version().Return(uint64(2)).Times(2) + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + caps3, err := manager.Discover(ctx, backends) + require.NoError(t, err) + require.NotNil(t, caps3) +} + +// TestDefaultManager_CacheAtCapacity tests cache eviction when at capacity +func TestDefaultManager_CacheAtCapacity(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // Create many different cache keys to fill cache + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(maxCacheSize + 1) // One more than capacity + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + // Fill cache to capacity + for i := 0; i < maxCacheSize; i++ { + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i)), + }) + + backends := []vmcp.Backend{ + {ID: "backend-" + string(rune(i)), Name: "Backend"}, + } + + _, err := manager.Discover(ctx, backends) + require.NoError(t, err) + } + + // Next discovery should not cache (at capacity) + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-new", + }) + + backends := []vmcp.Backend{ + {ID: "backend-new", Name: "Backend"}, + } + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} + +// TestDefaultManager_CacheAtCapacity_ExistingKey tests cache update when at capacity but key exists +func TestDefaultManager_CacheAtCapacity_ExistingKey(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockAggregator := aggmocks.NewMockAggregator(ctrl) + + // First call + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + manager, err := NewManager(mockAggregator) + require.NoError(t, err) + defer manager.Stop() + + ctx := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-1", + }) + + backends := []vmcp.Backend{ + {ID: "backend-1", Name: "Backend 1"}, + } + + // First discovery + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) + + // Fill cache to capacity with other keys + for i := 0; i < maxCacheSize-1; i++ { + ctxOther := context.WithValue(context.Background(), auth.IdentityContextKey{}, &auth.Identity{ + Subject: "user-" + string(rune(i+2)), + }) + + backendsOther := []vmcp.Backend{ + {ID: "backend-" + string(rune(i+2)), Name: "Backend"}, + } + + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err := manager.Discover(ctxOther, backendsOther) + require.NoError(t, err) + } + + // Update existing key should work even at capacity + mockAggregator.EXPECT(). + AggregateCapabilities(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + Times(1) + + _, err = manager.Discover(ctx, backends) + require.NoError(t, err) +} diff --git a/pkg/vmcp/health/checker.go b/pkg/vmcp/health/checker.go index 9705a0b788..c84a341fc4 100644 --- a/pkg/vmcp/health/checker.go +++ b/pkg/vmcp/health/checker.go @@ -8,6 +8,8 @@ import ( "context" "errors" "fmt" + "net/url" + "strings" "time" "github.com/stacklok/toolhive/pkg/logger" @@ -26,6 +28,10 @@ type healthChecker struct { // If a health check succeeds but takes longer than this duration, the backend is marked degraded. // Zero means disabled (backends will never be marked degraded based on response time alone). degradedThreshold time.Duration + + // selfURL is the server's own URL. If a health check targets this URL, it's short-circuited. + // This prevents the server from trying to health check itself. + selfURL string } // NewHealthChecker creates a new health checker that uses BackendClient.ListCapabilities @@ -36,13 +42,20 @@ type healthChecker struct { // - client: BackendClient for communicating with backend MCP servers // - timeout: Maximum duration for health check operations (0 = no timeout) // - degradedThreshold: Response time threshold for marking backend as degraded (0 = disabled) +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns a new HealthChecker implementation. -func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degradedThreshold time.Duration) vmcp.HealthChecker { +func NewHealthChecker( + client vmcp.BackendClient, + timeout time.Duration, + degradedThreshold time.Duration, + selfURL string, +) vmcp.HealthChecker { return &healthChecker{ client: client, timeout: timeout, degradedThreshold: degradedThreshold, + selfURL: selfURL, } } @@ -59,16 +72,28 @@ func NewHealthChecker(client vmcp.BackendClient, timeout time.Duration, degraded // The error return is informational and provides context about what failed. // The BackendHealthStatus return indicates the categorized health state. func (h *healthChecker) CheckHealth(ctx context.Context, target *vmcp.BackendTarget) (vmcp.BackendHealthStatus, error) { - // Apply timeout if configured - checkCtx := ctx + // Mark context as health check to bypass authentication logging + // Health checks verify backend availability and should not require user credentials + healthCheckCtx := WithHealthCheckMarker(ctx) + + // Apply timeout if configured (after adding health check marker) + checkCtx := healthCheckCtx var cancel context.CancelFunc if h.timeout > 0 { - checkCtx, cancel = context.WithTimeout(ctx, h.timeout) + checkCtx, cancel = context.WithTimeout(healthCheckCtx, h.timeout) defer cancel() } logger.Debugf("Performing health check for backend %s (%s)", target.WorkloadName, target.BaseURL) + // Short-circuit health check if targeting ourselves + // This prevents the server from trying to health check itself, which would work + // but is wasteful and can cause connection issues during startup + if h.selfURL != "" && h.isSelfCheck(target.BaseURL) { + logger.Debugf("Skipping health check for backend %s - this is the server itself", target.WorkloadName) + return vmcp.BackendHealthy, nil + } + // Track response time for degraded detection startTime := time.Now() @@ -134,3 +159,62 @@ func categorizeError(err error) vmcp.BackendHealthStatus { // Default to unhealthy for unknown errors return vmcp.BackendUnhealthy } + +// isSelfCheck checks if a backend URL matches the server's own URL. +// URLs are normalized before comparison to handle variations like: +// - http://127.0.0.1:PORT vs http://localhost:PORT +// - http://HOST:PORT vs http://HOST:PORT/ +func (h *healthChecker) isSelfCheck(backendURL string) bool { + if h.selfURL == "" || backendURL == "" { + return false + } + + // Normalize both URLs for comparison + backendNormalized, err := NormalizeURLForComparison(backendURL) + if err != nil { + return false + } + + selfNormalized, err := NormalizeURLForComparison(h.selfURL) + if err != nil { + return false + } + + return backendNormalized == selfNormalized +} + +// NormalizeURLForComparison normalizes a URL for comparison by: +// - Parsing and reconstructing the URL +// - Converting localhost/127.0.0.1 to a canonical form +// - Comparing only scheme://host:port (ignoring path, query, fragment) +// - Lowercasing scheme and host +// Exported for testing purposes +func NormalizeURLForComparison(rawURL string) (string, error) { + u, err := url.Parse(rawURL) + if err != nil { + return "", err + } + // Validate that we have a scheme and host (basic URL validation) + if u.Scheme == "" || u.Host == "" { + return "", fmt.Errorf("invalid URL: missing scheme or host") + } + + // Normalize host: convert localhost to 127.0.0.1 for consistency + host := strings.ToLower(u.Hostname()) + if host == "localhost" { + host = "127.0.0.1" + } + + // Reconstruct URL with normalized components (scheme://host:port only) + // We ignore path, query, and fragment for comparison + normalized := &url.URL{ + Scheme: strings.ToLower(u.Scheme), + } + if u.Port() != "" { + normalized.Host = host + ":" + u.Port() + } else { + normalized.Host = host + } + + return normalized.String(), nil +} diff --git a/pkg/vmcp/health/checker_selfcheck_test.go b/pkg/vmcp/health/checker_selfcheck_test.go new file mode 100644 index 0000000000..fc42b071f0 --- /dev/null +++ b/pkg/vmcp/health/checker_selfcheck_test.go @@ -0,0 +1,501 @@ +package health + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" +) + +// TestHealthChecker_CheckHealth_SelfCheck tests self-check detection +func TestHealthChecker_CheckHealth_SelfCheck(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + // Should not call ListCapabilities for self-check + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", // Same as selfURL + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_Localhost tests localhost normalization +func TestHealthChecker_CheckHealth_SelfCheck_Localhost(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://localhost:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", // localhost should match 127.0.0.1 + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_Reverse tests reverse localhost normalization +func TestHealthChecker_CheckHealth_SelfCheck_Reverse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", // 127.0.0.1 should match localhost + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_DifferentPort tests different ports don't match +func TestHealthChecker_CheckHealth_SelfCheck_DifferentPort(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8081", // Different port + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_EmptyURL tests empty URLs +func TestHealthChecker_CheckHealth_SelfCheck_EmptyURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_InvalidURL tests invalid URLs +func TestHealthChecker_CheckHealth_SelfCheck_InvalidURL(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "not-a-valid-url") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_SelfCheck_WithPath tests URLs with paths are normalized +func TestHealthChecker_CheckHealth_SelfCheck_WithPath(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Times(0) + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "http://127.0.0.1:8080") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://127.0.0.1:8080/mcp", // Path should be ignored + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status) +} + +// TestHealthChecker_CheckHealth_DegradedThreshold tests degraded threshold detection +func TestHealthChecker_CheckHealth_DegradedThreshold(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). + Times(1) + + // Set degraded threshold to 100ms + checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendDegraded, status, "Should mark as degraded when response time exceeds threshold") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_Disabled tests disabled degraded threshold +func TestHealthChecker_CheckHealth_DegradedThreshold_Disabled(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + // Simulate slow response + time.Sleep(150 * time.Millisecond) + return &vmcp.CapabilityList{}, nil + }). + Times(1) + + // Set degraded threshold to 0 (disabled) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when threshold is disabled") +} + +// TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse tests fast response doesn't trigger degraded +func TestHealthChecker_CheckHealth_DegradedThreshold_FastResponse(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockClient := mocks.NewMockBackendClient(ctrl) + mockClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + Times(1) + + // Set degraded threshold to 100ms + checker := NewHealthChecker(mockClient, 5*time.Second, 100*time.Millisecond, "") + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "test-backend", + BaseURL: "http://localhost:8080", + } + + status, err := checker.CheckHealth(context.Background(), target) + assert.NoError(t, err) + assert.Equal(t, vmcp.BackendHealthy, status, "Should not mark as degraded when response is fast") +} + +// TestCategorizeError_SentinelErrors tests sentinel error categorization +func TestCategorizeError_SentinelErrors(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + expectedStatus vmcp.BackendHealthStatus + }{ + { + name: "ErrAuthenticationFailed", + err: vmcp.ErrAuthenticationFailed, + expectedStatus: vmcp.BackendUnauthenticated, + }, + { + name: "ErrAuthorizationFailed", + err: vmcp.ErrAuthorizationFailed, + expectedStatus: vmcp.BackendUnauthenticated, + }, + { + name: "ErrTimeout", + err: vmcp.ErrTimeout, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "ErrCancelled", + err: vmcp.ErrCancelled, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "ErrBackendUnavailable", + err: vmcp.ErrBackendUnavailable, + expectedStatus: vmcp.BackendUnhealthy, + }, + { + name: "wrapped ErrAuthenticationFailed", + err: errors.New("wrapped: " + vmcp.ErrAuthenticationFailed.Error()), + expectedStatus: vmcp.BackendUnauthenticated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + status := categorizeError(tt.err) + assert.Equal(t, tt.expectedStatus, status) + }) + } +} + +// TestNormalizeURLForComparison tests URL normalization +func TestNormalizeURLForComparison(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + wantErr bool + }{ + { + name: "localhost normalized to 127.0.0.1", + input: "http://localhost:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "127.0.0.1 stays as is", + input: "http://127.0.0.1:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "path is ignored", + input: "http://127.0.0.1:8080/mcp", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "query is ignored", + input: "http://127.0.0.1:8080?param=value", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "fragment is ignored", + input: "http://127.0.0.1:8080#fragment", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "scheme is lowercased", + input: "HTTP://127.0.0.1:8080", + expected: "http://127.0.0.1:8080", + wantErr: false, + }, + { + name: "host is lowercased", + input: "http://EXAMPLE.COM:8080", + expected: "http://example.com:8080", + wantErr: false, + }, + { + name: "no port", + input: "http://127.0.0.1", + expected: "http://127.0.0.1", + wantErr: false, + }, + { + name: "invalid URL", + input: "not-a-url", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result, err := NormalizeURLForComparison(tt.input) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +// TestIsSelfCheck_EdgeCases tests edge cases for self-check detection +func TestIsSelfCheck_EdgeCases(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(func() { ctrl.Finish() }) + + mockClient := mocks.NewMockBackendClient(ctrl) + + tests := []struct { + name string + selfURL string + backendURL string + expected bool + }{ + { + name: "both empty", + selfURL: "", + backendURL: "", + expected: false, + }, + { + name: "selfURL empty", + selfURL: "", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "backendURL empty", + selfURL: "http://127.0.0.1:8080", + backendURL: "", + expected: false, + }, + { + name: "localhost matches 127.0.0.1", + selfURL: "http://localhost:8080", + backendURL: "http://127.0.0.1:8080", + expected: true, + }, + { + name: "127.0.0.1 matches localhost", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://localhost:8080", + expected: true, + }, + { + name: "different ports", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8081", + expected: false, + }, + { + name: "different hosts", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://192.168.1.1:8080", + expected: false, + }, + { + name: "path ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080/mcp", + expected: true, + }, + { + name: "query ignored", + selfURL: "http://127.0.0.1:8080", + backendURL: "http://127.0.0.1:8080?param=value", + expected: true, + }, + { + name: "invalid selfURL", + selfURL: "not-a-url", + backendURL: "http://127.0.0.1:8080", + expected: false, + }, + { + name: "invalid backendURL", + selfURL: "http://127.0.0.1:8080", + backendURL: "not-a-url", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + checker := NewHealthChecker(mockClient, 5*time.Second, 0, tt.selfURL) + hc, ok := checker.(*healthChecker) + require.True(t, ok) + + result := hc.isSelfCheck(tt.backendURL) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/pkg/vmcp/health/checker_test.go b/pkg/vmcp/health/checker_test.go index a0515cb3c2..818021bd33 100644 --- a/pkg/vmcp/health/checker_test.go +++ b/pkg/vmcp/health/checker_test.go @@ -41,7 +41,7 @@ func TestNewHealthChecker(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - checker := NewHealthChecker(mockClient, tt.timeout, 0) + checker := NewHealthChecker(mockClient, tt.timeout, 0, "") require.NotNil(t, checker) // Type assert to access internals for verification @@ -65,7 +65,7 @@ func TestHealthChecker_CheckHealth_Success(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -92,7 +92,7 @@ func TestHealthChecker_CheckHealth_ContextCancellation(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -120,7 +120,7 @@ func TestHealthChecker_CheckHealth_NoTimeout(t *testing.T) { Times(1) // Create checker with no timeout - checker := NewHealthChecker(mockClient, 0, 0) + checker := NewHealthChecker(mockClient, 0, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -210,7 +210,7 @@ func TestHealthChecker_CheckHealth_ErrorCategorization(t *testing.T) { Return(nil, tt.err). Times(1) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -427,7 +427,7 @@ func TestHealthChecker_CheckHealth_Timeout(t *testing.T) { }). Times(1) - checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0) + checker := NewHealthChecker(mockClient, 100*time.Millisecond, 0, "") target := &vmcp.BackendTarget{ WorkloadID: "backend-1", WorkloadName: "test-backend", @@ -464,7 +464,7 @@ func TestHealthChecker_CheckHealth_MultipleBackends(t *testing.T) { }). Times(4) - checker := NewHealthChecker(mockClient, 5*time.Second, 0) + checker := NewHealthChecker(mockClient, 5*time.Second, 0, "") // Test healthy backend status, err := checker.CheckHealth(context.Background(), &vmcp.BackendTarget{ diff --git a/pkg/vmcp/health/monitor.go b/pkg/vmcp/health/monitor.go index aa6a26240f..ee49e5bf70 100644 --- a/pkg/vmcp/health/monitor.go +++ b/pkg/vmcp/health/monitor.go @@ -105,12 +105,14 @@ func DefaultConfig() MonitorConfig { // - client: BackendClient for communicating with backend MCP servers // - backends: List of backends to monitor // - config: Configuration for health monitoring +// - selfURL: Optional server's own URL. If provided, health checks targeting this URL are short-circuited. // // Returns (monitor, error). Error is returned if configuration is invalid. func NewMonitor( client vmcp.BackendClient, backends []vmcp.Backend, config MonitorConfig, + selfURL string, ) (*Monitor, error) { // Validate configuration if config.CheckInterval <= 0 { @@ -120,8 +122,8 @@ func NewMonitor( return nil, fmt.Errorf("unhealthy threshold must be >= 1, got %d", config.UnhealthyThreshold) } - // Create health checker with degraded threshold - checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold) + // Create health checker with degraded threshold and self URL + checker := NewHealthChecker(client, config.Timeout, config.DegradedThreshold, selfURL) // Create status tracker statusTracker := newStatusTracker(config.UnhealthyThreshold) diff --git a/pkg/vmcp/health/monitor_test.go b/pkg/vmcp/health/monitor_test.go index 0bb74f163f..36defadd04 100644 --- a/pkg/vmcp/health/monitor_test.go +++ b/pkg/vmcp/health/monitor_test.go @@ -63,7 +63,7 @@ func TestNewMonitor_Validation(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, tt.config) + monitor, err := NewMonitor(mockClient, backends, tt.config, "") if tt.expectError { assert.Error(t, err) assert.Nil(t, monitor) @@ -98,7 +98,7 @@ func TestMonitor_StartStop(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start monitor @@ -175,7 +175,7 @@ func TestMonitor_StartErrors(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) err = tt.setupFunc(monitor) @@ -205,7 +205,7 @@ func TestMonitor_StopWithoutStart(t *testing.T) { Timeout: 50 * time.Millisecond, } - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Try to stop without starting @@ -236,7 +236,7 @@ func TestMonitor_PeriodicHealthChecks(t *testing.T) { Return(nil, errors.New("backend unavailable")). MinTimes(2) - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -286,7 +286,7 @@ func TestMonitor_GetHealthSummary(t *testing.T) { }). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -330,7 +330,7 @@ func TestMonitor_GetBackendStatus(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -379,7 +379,7 @@ func TestMonitor_GetBackendState(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -430,7 +430,7 @@ func TestMonitor_GetAllBackendStates(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) ctx := context.Background() @@ -474,7 +474,7 @@ func TestMonitor_ContextCancellation(t *testing.T) { Return(&vmcp.CapabilityList{}, nil). AnyTimes() - monitor, err := NewMonitor(mockClient, backends, config) + monitor, err := NewMonitor(mockClient, backends, config, "") require.NoError(t, err) // Start with cancellable context diff --git a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go new file mode 100644 index 0000000000..a539937fe9 --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go @@ -0,0 +1,690 @@ +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +const ( + testBackendOllama = "ollama" + testBackendOpenAI = "openai" +) + +// verifyEmbeddingBackendWorking verifies that the embedding backend is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyEmbeddingBackendWorking(t *testing.T, manager *embeddings.Manager, backendType string) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + if backendType == testBackendOllama { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } else { + t.Skipf("Skipping test: Embedding backend is reachable but embedding generation failed. Error: %v", err) + } + } +} + +// TestFindTool_SemanticSearch tests semantic search capabilities +// These tests verify that find_tool can find tools based on semantic meaning, +// not just exact keyword matches +func TestFindTool_SemanticSearch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingBackend := testBackendOllama + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, // all-MiniLM-L6-v2 dimension + } + + // Test if Ollama is available + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible (might be vLLM or Ollama v1 API) + embeddingConfig.BackendType = testBackendOpenAI + embeddingConfig.BaseURL = "http://localhost:11434" + embeddingConfig.Model = embeddings.DefaultModelAllMiniLM + embeddingConfig.Dimension = 768 + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping semantic search test: No embedding backend available (Ollama or OpenAI-compatible). Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + + // Setup optimizer integration with high semantic ratio to favor semantic search + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddingConfig.Model, + Dimension: embeddingConfig.Dimension, + }, + HybridSearchRatio: 0.9, // 90% semantic, 10% BM25 to test semantic search + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with diverse descriptions to test semantic understanding + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "github_get_branch", + Description: "Get information about a branch in a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases for semantic search - queries that mean the same thing but use different words + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should be found semantically + description string + }{ + { + name: "semantic_pr_synonyms", + query: "view code review request", + keywords: "", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"}, + description: "Should find PR tools using semantic synonyms (code review = pull request)", + }, + { + name: "semantic_merge_synonyms", + query: "combine code changes", + keywords: "", + expectedTools: []string{"github_merge_pull_request"}, + description: "Should find merge tool using semantic meaning (combine = merge)", + }, + { + name: "semantic_create_synonyms", + query: "make a new code review", + keywords: "", + expectedTools: []string{"github_create_pull_request", "github_list_pull_requests", "github_pull_request_read"}, + description: "Should find PR-related tools using semantic meaning (make = create, code review = PR)", + }, + { + name: "semantic_issue_synonyms", + query: "show bug reports", + keywords: "", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + description: "Should find issue tools using semantic synonyms (bug report = issue)", + }, + { + name: "semantic_repository_synonyms", + query: "start a new project", + keywords: "", + expectedTools: []string{"github_create_repository"}, + description: "Should find repository tool using semantic meaning (project = repository)", + }, + { + name: "semantic_commit_synonyms", + query: "get change details", + keywords: "", + expectedTools: []string{"github_get_commit"}, + description: "Should find commit tool using semantic meaning (change = commit)", + }, + { + name: "semantic_fetch_synonyms", + query: "download web page content", + keywords: "", + expectedTools: []string{"fetch_fetch"}, + description: "Should find fetch tool using semantic synonyms (download = fetch)", + }, + { + name: "semantic_branch_synonyms", + query: "get branch information", + keywords: "", + expectedTools: []string{"github_get_branch"}, + description: "Should find branch tool using semantic meaning", + }, + { + name: "semantic_related_concepts", + query: "code collaboration features", + keywords: "", + expectedTools: []string{"github_pull_request_read", "github_create_pull_request", "github_issue_read"}, + description: "Should find collaboration-related tools (PRs and issues are collaboration features)", + }, + { + name: "semantic_intent_based", + query: "I want to see what code changes were made", + keywords: "", + expectedTools: []string{"github_get_commit", "github_pull_request_read"}, + description: "Should find tools based on user intent (seeing code changes = commits/PRs)", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error for query: %s", tc.query) + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText, "Result should be text content") + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray, "Response should have tools array") + require.NotEmpty(t, toolsArray, "Should return at least one result for semantic query: %s", tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = append(foundTools, toolName) + + // Verify similarity score exists and is reasonable + similarity, okScore := toolMap["similarity_score"].(float64) + require.True(t, okScore, "Tool should have similarity_score") + assert.Greater(t, similarity, 0.0, "Similarity score should be positive") + } + + // Check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + assert.GreaterOrEqual(t, foundCount, 1, + "Semantic query '%s' should find at least one expected tool from %v. Found tools: %v (found %d/%d)", + tc.query, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log results for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Semantic query '%s': Found %d/%d expected tools. Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, okMetrics := response["token_metrics"].(map[string]interface{}) + require.True(t, okMetrics, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + }) + } +} + +// TestFindTool_SemanticVsKeyword tests that semantic search finds different results than keyword search +func TestFindTool_SemanticVsKeyword(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Test with high semantic ratio + configSemantic := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-semantic"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.9, // 90% semantic + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integrationSemantic, err := NewIntegration(ctx, configSemantic, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationSemantic.Close() }() + + // Test with low semantic ratio (high BM25) + configKeyword := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-keyword"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.1, // 10% semantic, 90% BM25 + } + + integrationKeyword, err := NewIntegration(ctx, configKeyword, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationKeyword.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Register both integrations + err = integrationSemantic.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + err = integrationKeyword.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integrationSemantic.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + err = integrationKeyword.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + // Query that has semantic meaning but no exact keyword match + query := "view code review" + + // Test semantic search + requestSemantic := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerSemantic := integrationSemantic.CreateFindToolHandler() + resultSemantic, err := handlerSemantic(ctxWithCaps, requestSemantic) + require.NoError(t, err) + require.False(t, resultSemantic.IsError) + + // Test keyword search + requestKeyword := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerKeyword := integrationKeyword.CreateFindToolHandler() + resultKeyword, err := handlerKeyword(ctxWithCaps, requestKeyword) + require.NoError(t, err) + require.False(t, resultKeyword.IsError) + + // Parse both results + textSemantic, _ := mcp.AsTextContent(resultSemantic.Content[0]) + var responseSemantic map[string]any + json.Unmarshal([]byte(textSemantic.Text), &responseSemantic) + + textKeyword, _ := mcp.AsTextContent(resultKeyword.Content[0]) + var responseKeyword map[string]any + json.Unmarshal([]byte(textKeyword.Text), &responseKeyword) + + toolsSemantic, _ := responseSemantic["tools"].([]interface{}) + toolsKeyword, _ := responseKeyword["tools"].([]interface{}) + + // Both should find results (semantic should find PR tools, keyword might not) + assert.NotEmpty(t, toolsSemantic, "Semantic search should find results") + assert.NotEmpty(t, toolsKeyword, "Keyword search should find results") + + // Semantic search should find pull request tools even without exact keyword match + foundPRSemantic := false + for _, toolInterface := range toolsSemantic { + toolMap, _ := toolInterface.(map[string]interface{}) + toolName, _ := toolMap["name"].(string) + if toolName == "github_pull_request_read" { + foundPRSemantic = true + break + } + } + + t.Logf("Semantic search (90%% semantic): Found %d tools", len(toolsSemantic)) + t.Logf("Keyword search (10%% semantic): Found %d tools", len(toolsKeyword)) + t.Logf("Semantic search found PR tool: %v", foundPRSemantic) + + // Semantic search should be able to find semantically related tools + // even when keywords don't match exactly + assert.True(t, foundPRSemantic, + "Semantic search should find 'github_pull_request_read' for query 'view code review' even without exact keyword match") +} + +// TestFindTool_SemanticSimilarityScores tests that similarity scores are meaningful +func TestFindTool_SemanticSimilarityScores(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: embeddingConfig.BaseURL, + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.9, // High semantic ratio + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Query for pull request + query := "view pull request" + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError) + + textContent, _ := mcp.AsTextContent(result.Content[0]) + var response map[string]any + json.Unmarshal([]byte(textContent.Text), &response) + + toolsArray, _ := response["tools"].([]interface{}) + require.NotEmpty(t, toolsArray) + + // Check that results are sorted by similarity (highest first) + var similarities []float64 + for _, toolInterface := range toolsArray { + toolMap, _ := toolInterface.(map[string]interface{}) + similarity, _ := toolMap["similarity_score"].(float64) + similarities = append(similarities, similarity) + } + + // Verify results are sorted by similarity (descending) + for i := 1; i < len(similarities); i++ { + assert.GreaterOrEqual(t, similarities[i-1], similarities[i], + "Results should be sorted by similarity score (descending). Scores: %v", similarities) + } + + // The most relevant tool (pull request) should have a higher similarity than unrelated tools + if len(similarities) > 1 { + // First result should have highest similarity + assert.Greater(t, similarities[0], 0.0, "Top result should have positive similarity") + } +} diff --git a/pkg/vmcp/optimizer/find_tool_string_matching_test.go b/pkg/vmcp/optimizer/find_tool_string_matching_test.go new file mode 100644 index 0000000000..b994d7b95d --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_string_matching_test.go @@ -0,0 +1,696 @@ +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// verifyOllamaWorking verifies that Ollama is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyOllamaWorking(t *testing.T, manager *embeddings.Manager) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } +} + +// getRealToolData returns test data based on actual MCP server tools +// These are real tool descriptions from GitHub and other MCP servers +func getRealToolData() []vmcp.Tool { + return []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead.", + BackendID: "github", + }, + { + Name: "github_search_pull_requests", + Description: "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_pull_request_review_write", + Description: "Create and/or submit, delete review of a pull request.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository. For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } +} + +// TestFindTool_StringMatching tests that find_tool can match strings correctly +func TestFindTool_StringMatching(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.5, // 50% semantic, 50% BM25 for better string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Get real tool data + tools := getRealToolData() + + // Create capabilities with real tools + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + // Build routing table + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + // Register session and generate embeddings + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + // Create context with capabilities + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases: query -> expected tool names that should be found + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should definitely be in results + minResults int // Minimum number of results expected + description string + }{ + { + name: "exact_pull_request_match", + query: "pull request", + keywords: "pull request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests", "github_create_pull_request"}, + minResults: 3, + description: "Should find tools with exact 'pull request' string match", + }, + { + name: "pull_request_in_name", + query: "pull request", + keywords: "pull_request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"}, + minResults: 2, + description: "Should match tools with 'pull_request' in name", + }, + { + name: "list_pull_requests", + query: "list pull requests", + keywords: "list pull requests", + expectedTools: []string{"github_list_pull_requests"}, + minResults: 1, + description: "Should find list pull requests tool", + }, + { + name: "read_pull_request", + query: "read pull request", + keywords: "read pull request", + expectedTools: []string{"github_pull_request_read"}, + minResults: 1, + description: "Should find read pull request tool", + }, + { + name: "create_pull_request", + query: "create pull request", + keywords: "create pull request", + expectedTools: []string{"github_create_pull_request"}, + minResults: 1, + description: "Should find create pull request tool", + }, + { + name: "merge_pull_request", + query: "merge pull request", + keywords: "merge pull request", + expectedTools: []string{"github_merge_pull_request"}, + minResults: 1, + description: "Should find merge pull request tool", + }, + { + name: "search_pull_requests", + query: "search pull requests", + keywords: "search pull requests", + expectedTools: []string{"github_search_pull_requests"}, + minResults: 1, + description: "Should find search pull requests tool", + }, + { + name: "issue_tools", + query: "issue", + keywords: "issue", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + minResults: 2, + description: "Should find issue-related tools", + }, + { + name: "repository_tool", + query: "create repository", + keywords: "create repository", + expectedTools: []string{"github_create_repository"}, + minResults: 1, + description: "Should find create repository tool", + }, + { + name: "commit_tool", + query: "get commit", + keywords: "commit", + expectedTools: []string{"github_get_commit"}, + minResults: 1, + description: "Should find get commit tool", + }, + { + name: "fetch_tool", + query: "fetch URL", + keywords: "fetch", + expectedTools: []string{"fetch_fetch"}, + minResults: 1, + description: "Should find fetch tool", + }, + } + + for _, tc := range testCases { + tc := tc // capture loop variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create the tool call request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 20, + }, + }, + } + + // Call the handler + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error") + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok, "Result should be text content") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + // Check tools array exists + toolsArray, ok := response["tools"].([]interface{}) + require.True(t, ok, "Response should have tools array") + require.GreaterOrEqual(t, len(toolsArray), tc.minResults, + "Should return at least %d results for query: %s", tc.minResults, tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = append(foundTools, toolName) + } + + // Check that at least some expected tools are found + // String matching may not be perfect, so we check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + // We should find at least one expected tool, or at least 50% of expected tools + minExpected := 1 + if len(tc.expectedTools) > 1 { + half := len(tc.expectedTools) / 2 + if half > minExpected { + minExpected = half + } + } + + assert.GreaterOrEqual(t, foundCount, minExpected, + "Query '%s' should find at least %d of expected tools %v. Found tools: %v (found %d/%d)", + tc.query, minExpected, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log which expected tools were found for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Query '%s': Found %d/%d expected tools. Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]interface{}) + require.True(t, ok, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + assert.Contains(t, tokenMetrics, "tokens_saved") + assert.Contains(t, tokenMetrics, "savings_percentage") + }) + } +} + +// TestFindTool_ExactStringMatch tests that exact string matches work correctly +func TestFindTool_ExactStringMatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration with higher BM25 ratio for better string matching + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.3, // 30% semantic, 70% BM25 for better exact string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with specific strings to match + tools := []vmcp.Tool{ + { + Name: "test_pull_request_tool", + Description: "This tool handles pull requests in GitHub", + BackendID: "test", + }, + { + Name: "test_issue_tool", + Description: "This tool handles issues in GitHub", + BackendID: "test", + }, + { + Name: "test_repository_tool", + Description: "This tool creates repositories", + BackendID: "test", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "test", "test", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test exact string matching + testCases := []struct { + name string + query string + keywords string + expectedTool string + description string + }{ + { + name: "exact_pull_request_string", + query: "pull request", + keywords: "pull request", + expectedTool: "test_pull_request_tool", + description: "Should match exact 'pull request' string", + }, + { + name: "exact_issue_string", + query: "issue", + keywords: "issue", + expectedTool: "test_issue_tool", + description: "Should match exact 'issue' string", + }, + { + name: "exact_repository_string", + query: "repository", + keywords: "repository", + expectedTool: "test_repository_tool", + description: "Should match exact 'repository' string", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + require.NotEmpty(t, toolsArray, "Should find at least one tool for query: %s", tc.query) + + // Check that the expected tool is in the results + found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == tc.expectedTool { + found = true + break + } + } + + assert.True(t, found, + "Expected tool '%s' not found in results for query '%s'. This indicates string matching is not working correctly.", + tc.expectedTool, tc.query) + }) + } +} + +// TestFindTool_CaseInsensitive tests case-insensitive string matching +func TestFindTool_CaseInsensitive(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + HybridSearchRatio: 0.3, // Favor BM25 for string matching + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "github_pull_request_read": { + WorkloadID: "github", + WorkloadName: "github", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test different case variations + queries := []string{ + "PULL REQUEST", + "Pull Request", + "pull request", + "PuLl ReQuEsT", + } + + for _, query := range queries { + query := query + t.Run("case_"+strings.ToLower(query), func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": strings.ToLower(query), + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + + // Should find the pull request tool regardless of case + found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == "github_pull_request_read" { + found = true + break + } + } + + assert.True(t, found, + "Should find pull request tool with case-insensitive query: %s", query) + }) + } +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index 4a24d95576..19553ea2e1 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -13,7 +13,9 @@ package optimizer import ( "context" + "encoding/json" "fmt" + "sync" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -22,8 +24,11 @@ import ( "github.com/stacklok/toolhive/pkg/optimizer/db" "github.com/stacklok/toolhive/pkg/optimizer/embeddings" "github.com/stacklok/toolhive/pkg/optimizer/ingestion" + "github.com/stacklok/toolhive/pkg/optimizer/models" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" ) // Config holds optimizer configuration for vMCP integration. @@ -49,10 +54,12 @@ type Config struct { // //nolint:revive // Name is intentional for clarity in external packages type OptimizerIntegration struct { - config *Config - ingestionService *ingestion.Service - mcpServer *server.MCPServer // For registering tools - backendClient vmcp.BackendClient // For querying backends at startup + config *Config + ingestionService *ingestion.Service + mcpServer *server.MCPServer // For registering tools + backendClient vmcp.BackendClient // For querying backends at startup + sessionManager *transportsession.Manager + processedSessions sync.Map // Track sessions that have already been processed } // NewIntegration creates a new optimizer integration. @@ -61,6 +68,7 @@ func NewIntegration( cfg *Config, mcpServer *server.MCPServer, backendClient vmcp.BackendClient, + sessionManager *transportsession.Manager, ) (*OptimizerIntegration, error) { if cfg == nil || !cfg.Enabled { return nil, nil // Optimizer disabled @@ -85,6 +93,7 @@ func NewIntegration( ingestionService: svc, mcpServer: mcpServer, backendClient: backendClient, + sessionManager: sessionManager, }, nil } @@ -96,98 +105,30 @@ func NewIntegration( // 2. Generates embeddings for all tools (parallel per-backend) // 3. Registers optim.find_tool and optim.call_tool as session tools func (o *OptimizerIntegration) OnRegisterSession( - ctx context.Context, + _ context.Context, session server.ClientSession, - capabilities *aggregator.AggregatedCapabilities, + _ *aggregator.AggregatedCapabilities, ) error { if o == nil { return nil // Optimizer not enabled } sessionID := session.SessionID() - logger.Infow("Generating embeddings for session", "session_id", sessionID) - - // Group tools by backend for parallel processing - type backendTools struct { - backendID string - backendName string - backendURL string - transport string - tools []mcp.Tool - } - - backendMap := make(map[string]*backendTools) - - // Extract tools from routing table - if capabilities.RoutingTable != nil { - for toolName, target := range capabilities.RoutingTable.Tools { - // Find the tool definition from capabilities.Tools - var toolDef mcp.Tool - found := false - for i := range capabilities.Tools { - if capabilities.Tools[i].Name == toolName { - // Convert vmcp.Tool to mcp.Tool - // Note: vmcp.Tool.InputSchema is map[string]any, mcp.Tool.InputSchema is ToolInputSchema struct - // For ingestion, we just need the tool name and description - toolDef = mcp.Tool{ - Name: capabilities.Tools[i].Name, - Description: capabilities.Tools[i].Description, - // InputSchema will be empty - we only need name/description for embedding generation - } - found = true - break - } - } - if !found { - logger.Warnw("Tool in routing table but not in capabilities", - "tool_name", toolName, - "backend_id", target.WorkloadID) - continue - } - // Group by backend - if _, exists := backendMap[target.WorkloadID]; !exists { - backendMap[target.WorkloadID] = &backendTools{ - backendID: target.WorkloadID, - backendName: target.WorkloadName, - backendURL: target.BaseURL, - transport: target.TransportType, - tools: []mcp.Tool{}, - } - } - backendMap[target.WorkloadID].tools = append(backendMap[target.WorkloadID].tools, toolDef) - } - } + logger.Debugw("OnRegisterSession called", "session_id", sessionID) - // Ingest each backend's tools (in parallel - TODO: add goroutines) - for _, bt := range backendMap { - logger.Debugw("Ingesting backend for session", - "session_id", sessionID, - "backend_id", bt.backendID, - "backend_name", bt.backendName, - "tool_count", len(bt.tools)) - - // Ingest server with simplified metadata - // Note: URL and transport are not stored - vMCP manages backend lifecycle - err := o.ingestionService.IngestServer( - ctx, - bt.backendID, - bt.backendName, - nil, // description - bt.tools, - ) - if err != nil { - logger.Errorw("Failed to ingest backend", - "session_id", sessionID, - "backend_id", bt.backendID, - "error", err) - // Continue with other backends - } + // Check if this session has already been processed + if _, alreadyProcessed := o.processedSessions.LoadOrStore(sessionID, true); alreadyProcessed { + logger.Debugw("Session already processed, skipping duplicate ingestion", + "session_id", sessionID) + return nil } - logger.Infow("Embeddings generated for session", - "session_id", sessionID, - "backend_count", len(backendMap)) + // Skip ingestion in OnRegisterSession - IngestInitialBackends already handles ingestion at startup + // This prevents duplicate ingestion when sessions are registered + // The optimizer database is populated once at startup, not per-session + logger.Infow("Skipping ingestion in OnRegisterSession (handled by IngestInitialBackends at startup)", + "session_id", sessionID) return nil } @@ -252,7 +193,7 @@ func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.C Required: []string{"backend_id", "tool_name", "parameters"}, }, }, - Handler: o.createCallToolHandler(), + Handler: o.CreateCallToolHandler(), }, } @@ -265,32 +206,255 @@ func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.C return nil } -// createFindToolHandler creates the handler for optim.find_tool -func (*OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // TODO: Implement semantic search - // 1. Extract tool_description and tool_keywords from request.Params.Arguments - // 2. Call optimizer search service (hybrid semantic + BM25) - // 3. Return ranked list of tools with scores and token metrics +// CreateFindToolHandler creates the handler for optim.find_tool +// Exported for testing purposes +func (o *OptimizerIntegration) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return o.createFindToolHandler() +} + +// extractFindToolParams extracts and validates parameters from the find_tool request +func extractFindToolParams(args map[string]any) (toolDescription, toolKeywords string, limit int, err *mcp.CallToolResult) { + // Extract tool_description (required) + toolDescription, ok := args["tool_description"].(string) + if !ok || toolDescription == "" { + return "", "", 0, mcp.NewToolResultError("tool_description is required and must be a non-empty string") + } + + // Extract tool_keywords (optional) + toolKeywords, _ = args["tool_keywords"].(string) + + // Extract limit (optional, default: 10) + limit = 10 + if limitVal, ok := args["limit"]; ok { + if limitFloat, ok := limitVal.(float64); ok { + limit = int(limitFloat) + } + } + + return toolDescription, toolKeywords, limit, nil +} +// convertSearchResultsToResponse converts database search results to the response format +func convertSearchResultsToResponse(results []*models.BackendToolWithMetadata) ([]map[string]any, int) { + responseTools := make([]map[string]any, 0, len(results)) + totalReturnedTokens := 0 + + for _, result := range results { + // Unmarshal InputSchema + var inputSchema map[string]any + if len(result.InputSchema) > 0 { + if err := json.Unmarshal(result.InputSchema, &inputSchema); err != nil { + logger.Warnw("Failed to unmarshal input schema", + "tool_id", result.ID, + "tool_name", result.ToolName, + "error", err) + inputSchema = map[string]any{} // Use empty schema on error + } + } + + // Handle nil description + description := "" + if result.Description != nil { + description = *result.Description + } + + tool := map[string]any{ + "name": result.ToolName, + "description": description, + "input_schema": inputSchema, + "backend_id": result.MCPServerID, + "similarity_score": result.Similarity, + "token_count": result.TokenCount, + } + responseTools = append(responseTools, tool) + totalReturnedTokens += result.TokenCount + } + + return responseTools, totalReturnedTokens +} + +// createFindToolHandler creates the handler for optim.find_tool +func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { logger.Debugw("optim.find_tool called", "request", request) - return mcp.NewToolResultError("optim.find_tool not yet implemented"), nil + // Extract parameters from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + // Extract and validate parameters + toolDescription, toolKeywords, limit, err := extractFindToolParams(args) + if err != nil { + return err, nil + } + + // Perform hybrid search using database operations + if o.ingestionService == nil { + return mcp.NewToolResultError("backend tool operations not initialized"), nil + } + backendToolOps := o.ingestionService.GetBackendToolOps() + if backendToolOps == nil { + return mcp.NewToolResultError("backend tool operations not initialized"), nil + } + + // Configure hybrid search + hybridConfig := &db.HybridSearchConfig{ + SemanticRatio: o.config.HybridSearchRatio, + Limit: limit, + ServerID: nil, // Search across all servers + } + + // Execute hybrid search + queryText := toolDescription + if toolKeywords != "" { + queryText = toolDescription + " " + toolKeywords + } + results, err2 := backendToolOps.SearchHybrid(ctx, queryText, hybridConfig) + if err2 != nil { + logger.Errorw("Hybrid search failed", + "error", err2, + "tool_description", toolDescription, + "tool_keywords", toolKeywords, + "query_text", queryText) + return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err2)), nil + } + + // Convert results to response format + responseTools, totalReturnedTokens := convertSearchResultsToResponse(results) + + // Calculate token metrics + baselineTokens := o.ingestionService.GetTotalToolTokens(ctx) + tokensSaved := baselineTokens - totalReturnedTokens + savingsPercentage := 0.0 + if baselineTokens > 0 { + savingsPercentage = (float64(tokensSaved) / float64(baselineTokens)) * 100.0 + } + + tokenMetrics := map[string]any{ + "baseline_tokens": baselineTokens, + "returned_tokens": totalReturnedTokens, + "tokens_saved": tokensSaved, + "savings_percentage": savingsPercentage, + } + + // Build response + response := map[string]any{ + "tools": responseTools, + "token_metrics": tokenMetrics, + } + + // Marshal to JSON for the result + responseJSON, err3 := json.Marshal(response) + if err3 != nil { + logger.Errorw("Failed to marshal response", "error", err3) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal response: %v", err3)), nil + } + + logger.Infow("optim.find_tool completed", + "query", toolDescription, + "results_count", len(responseTools), + "tokens_saved", tokensSaved, + "savings_percentage", fmt.Sprintf("%.2f%%", savingsPercentage)) + + return mcp.NewToolResultText(string(responseJSON)), nil } } -// createCallToolHandler creates the handler for optim.call_tool -func (*OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(_ context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - // TODO: Implement dynamic tool invocation - // 1. Extract backend_id, tool_name, parameters from request.Params.Arguments - // 2. Validate backend and tool exist - // 3. Route to backend via existing router - // 4. Return result +// CreateCallToolHandler creates the handler for optim.call_tool +// Exported for testing purposes +func (o *OptimizerIntegration) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return o.createCallToolHandler() +} +// createCallToolHandler creates the handler for optim.call_tool +func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { logger.Debugw("optim.call_tool called", "request", request) - return mcp.NewToolResultError("optim.call_tool not yet implemented"), nil + // Extract parameters from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + // Extract backend_id (required) + backendID, ok := args["backend_id"].(string) + if !ok || backendID == "" { + return mcp.NewToolResultError("backend_id is required and must be a non-empty string"), nil + } + + // Extract tool_name (required) + toolName, ok := args["tool_name"].(string) + if !ok || toolName == "" { + return mcp.NewToolResultError("tool_name is required and must be a non-empty string"), nil + } + + // Extract parameters (required) + parameters, ok := args["parameters"].(map[string]any) + if !ok { + return mcp.NewToolResultError("parameters is required and must be an object"), nil + } + + // Get routing table from context via discovered capabilities + capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) + if !ok || capabilities == nil { + return mcp.NewToolResultError("routing information not available in context"), nil + } + + if capabilities.RoutingTable == nil || capabilities.RoutingTable.Tools == nil { + return mcp.NewToolResultError("routing table not initialized"), nil + } + + // Find the tool in the routing table + target, exists := capabilities.RoutingTable.Tools[toolName] + if !exists { + return mcp.NewToolResultError(fmt.Sprintf("tool not found in routing table: %s", toolName)), nil + } + + // Verify the tool belongs to the specified backend + if target.WorkloadID != backendID { + return mcp.NewToolResultError(fmt.Sprintf( + "tool %s belongs to backend %s, not %s", + toolName, + target.WorkloadID, + backendID, + )), nil + } + + // Get the backend capability name (handles renamed tools) + backendToolName := target.GetBackendCapabilityName(toolName) + + logger.Infow("Calling tool via optimizer", + "backend_id", backendID, + "tool_name", toolName, + "backend_tool_name", backendToolName, + "workload_name", target.WorkloadName) + + // Call the tool on the backend using the backend client + result, err := o.backendClient.CallTool(ctx, target, backendToolName, parameters) + if err != nil { + logger.Errorw("Tool call failed", + "error", err, + "backend_id", backendID, + "tool_name", toolName, + "backend_tool_name", backendToolName) + return mcp.NewToolResultError(fmt.Sprintf("tool call failed: %v", err)), nil + } + + // Convert result to JSON + resultJSON, err := json.Marshal(result) + if err != nil { + logger.Errorw("Failed to marshal tool result", "error", err) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil + } + + logger.Infow("optim.call_tool completed successfully", + "backend_id", backendID, + "tool_name", toolName) + + return mcp.NewToolResultText(string(resultJSON)), nil } } @@ -362,3 +526,18 @@ func (o *OptimizerIntegration) Close() error { } return o.ingestionService.Close() } + +// IngestToolsForTesting manually ingests tools for testing purposes. +// This is a test helper that bypasses the normal ingestion flow. +func (o *OptimizerIntegration) IngestToolsForTesting( + ctx context.Context, + serverID string, + serverName string, + description *string, + tools []mcp.Tool, +) error { + if o == nil || o.ingestionService == nil { + return fmt.Errorf("optimizer integration not initialized") + } + return o.ingestionService.IngestServer(ctx, serverID, serverName, description, tools) +} diff --git a/pkg/vmcp/optimizer/optimizer_handlers_test.go b/pkg/vmcp/optimizer/optimizer_handlers_test.go new file mode 100644 index 0000000000..3889a47e37 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_handlers_test.go @@ -0,0 +1,1026 @@ +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockMCPServerWithSession implements AddSessionTools for testing +type mockMCPServerWithSession struct { + *server.MCPServer + toolsAdded map[string][]server.ServerTool +} + +func newMockMCPServerWithSession() *mockMCPServerWithSession { + return &mockMCPServerWithSession{ + MCPServer: server.NewMCPServer("test-server", "1.0"), + toolsAdded: make(map[string][]server.ServerTool), + } +} + +func (m *mockMCPServerWithSession) AddSessionTools(sessionID string, tools ...server.ServerTool) error { + m.toolsAdded[sessionID] = tools + return nil +} + +// mockBackendClientWithCallTool implements CallTool for testing +type mockBackendClientWithCallTool struct { + callToolResult map[string]any + callToolError error +} + +func (*mockBackendClientWithCallTool) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + return &vmcp.CapabilityList{}, nil +} + +func (m *mockBackendClientWithCallTool) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + if m.callToolError != nil { + return nil, m.callToolError + } + return m.callToolResult, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +// TestCreateFindToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "limit": 10, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_description") + + // Test with empty tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty tool_description") + + // Test with non-string tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": 123, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for non-string tool_description") +} + +// TestCreateFindToolHandler_WithKeywords tests find_tool with keywords +func TestCreateFindToolHandler_WithKeywords(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest a tool for testing + tools := []mcp.Tool{ + { + Name: "test_tool", + Description: "A test tool for searching", + }, + } + + err = integration.IngestToolsForTesting(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + handler := integration.CreateFindToolHandler() + + // Test with keywords + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "search tool", + "tool_keywords": "test search", + "limit": 10, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response structure + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + _, ok = response["tools"] + require.True(t, ok, "Response should have tools") + + _, ok = response["token_metrics"] + require.True(t, ok, "Response should have token_metrics") +} + +// TestCreateFindToolHandler_Limit tests limit parameter handling +func TestCreateFindToolHandler_Limit(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with custom limit + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) + + // Test with float64 limit (from JSON) + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": float64(3), + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) +} + +// TestCreateFindToolHandler_BackendToolOpsNil tests error when backend tool ops is nil +func TestCreateFindToolHandler_BackendToolOpsNil(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create integration with nil ingestion service to trigger error path + integration := &OptimizerIntegration{ + config: &Config{Enabled: true}, + ingestionService: nil, // This will cause GetBackendToolOps to return nil + } + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend tool ops is nil") +} + +// TestCreateCallToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing backend_id") + + // Test with empty backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty backend_id") + + // Test with missing tool_name + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_name") + + // Test with missing parameters + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing parameters") + + // Test with invalid parameters type + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": "not a map", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid parameters type") +} + +// TestCreateCallToolHandler_NoRoutingTable tests error when routing table is missing +func TestCreateCallToolHandler_NoRoutingTable(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test without routing table in context + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when routing table is missing") +} + +// TestCreateCallToolHandler_ToolNotFound tests error when tool is not found +func TestCreateCallToolHandler_ToolNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table but tool not found + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "nonexistent_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when tool is not found") +} + +// TestCreateCallToolHandler_BackendMismatch tests error when backend doesn't match +func TestCreateCallToolHandler_BackendMismatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table where tool belongs to different backend + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-2", // Different backend + WorkloadName: "Backend 2", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", // Requesting backend-1 + "tool_name": "test_tool", // But tool belongs to backend-2 + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend doesn't match") +} + +// TestCreateCallToolHandler_Success tests successful tool call +func TestCreateCallToolHandler_Success(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolResult: map[string]any{ + "result": "success", + }, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{ + "param1": "value1", + }, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "success", response["result"]) +} + +// TestCreateCallToolHandler_CallToolError tests error handling when CallTool fails +func TestCreateCallToolHandler_CallToolError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolError: assert.AnError, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when CallTool fails") +} + +// TestCreateFindToolHandler_InputSchemaUnmarshalError tests error handling for invalid input schema +func TestCreateFindToolHandler_InputSchemaUnmarshalError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + // The handler should handle invalid input schema gracefully + result, err := handler(ctx, request) + require.NoError(t, err) + // Should not error even if some tools have invalid schemas + require.False(t, result.IsError) +} + +// TestOnRegisterSession_DuplicateSession tests duplicate session handling +func TestOnRegisterSession_DuplicateSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + // First call + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Second call with same session ID (should be skipped) + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err, "Should handle duplicate session gracefully") +} + +// TestIngestInitialBackends_ErrorHandling tests error handling during ingestion +func TestIngestInitialBackends_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{ + err: assert.AnError, // Simulate error when listing capabilities + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // Should not fail even if backend query fails + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle backend query errors gracefully") +} + +// TestIngestInitialBackends_NilIntegration tests nil integration handling +func TestIngestInitialBackends_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + backends := []vmcp.Backend{} + + err := integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go index 82a51a925a..2fcb912743 100644 --- a/pkg/vmcp/optimizer/optimizer_integration_test.go +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -4,14 +4,17 @@ import ( "context" "path/filepath" "testing" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" ) // mockBackendClient implements vmcp.BackendClient for integration testing @@ -107,18 +110,36 @@ func TestOptimizerIntegration_WithVMCP(t *testing.T) { }, }) + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + // Configure optimizer optimizerConfig := &Config{ Enabled: true, PersistPath: filepath.Join(tmpDir, "optimizer-db"), EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, Dimension: 384, }, } // Create optimizer integration - integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient) + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) require.NoError(t, err) defer func() { _ = integration.Close() }() diff --git a/pkg/vmcp/optimizer/optimizer_unit_test.go b/pkg/vmcp/optimizer/optimizer_unit_test.go index 794069b851..8b09a99ee8 100644 --- a/pkg/vmcp/optimizer/optimizer_unit_test.go +++ b/pkg/vmcp/optimizer/optimizer_unit_test.go @@ -4,6 +4,7 @@ import ( "context" "path/filepath" "testing" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -11,8 +12,10 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" "github.com/stacklok/toolhive/pkg/vmcp" "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" ) // mockBackendClient implements vmcp.BackendClient for testing @@ -85,13 +88,13 @@ func TestNewIntegration_Disabled(t *testing.T) { ctx := context.Background() // Test with nil config - integration, err := NewIntegration(ctx, nil, nil, nil) + integration, err := NewIntegration(ctx, nil, nil, nil, nil) require.NoError(t, err) assert.Nil(t, integration, "Should return nil when config is nil") // Test with disabled config config := &Config{Enabled: false} - integration, err = NewIntegration(ctx, config, nil, nil) + integration, err = NewIntegration(ctx, config, nil, nil, nil) require.NoError(t, err) assert.Nil(t, integration, "Should return nil when optimizer is disabled") } @@ -102,6 +105,21 @@ func TestNewIntegration_Enabled(t *testing.T) { ctx := context.Background() tmpDir := t.TempDir() + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + mcpServer := server.NewMCPServer("test-server", "1.0") mockClient := &mockBackendClient{} @@ -109,12 +127,15 @@ func TestNewIntegration_Enabled(t *testing.T) { Enabled: true, PersistPath: filepath.Join(tmpDir, "optimizer-db"), EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", - Dimension: 384, + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, }, } - integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) require.NoError(t, err) require.NotNil(t, integration) defer func() { _ = integration.Close() }() @@ -129,16 +150,34 @@ func TestOnRegisterSession(t *testing.T) { mcpServer := server.NewMCPServer("test-server", "1.0") mockClient := &mockBackendClient{} + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + config := &Config{ Enabled: true, PersistPath: filepath.Join(tmpDir, "optimizer-db"), EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", - Dimension: 384, + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, }, } - integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) require.NoError(t, err) defer func() { _ = integration.Close() }() @@ -189,16 +228,34 @@ func TestRegisterTools(t *testing.T) { mcpServer := server.NewMCPServer("test-server", "1.0") mockClient := &mockBackendClient{} + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + config := &Config{ Enabled: true, PersistPath: filepath.Join(tmpDir, "optimizer-db"), EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", - Dimension: 384, + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, }, } - integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) require.NoError(t, err) defer func() { _ = integration.Close() }() @@ -230,16 +287,34 @@ func TestClose(t *testing.T) { mcpServer := server.NewMCPServer("test-server", "1.0") mockClient := &mockBackendClient{} + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + config := &Config{ Enabled: true, PersistPath: filepath.Join(tmpDir, "optimizer-db"), EmbeddingConfig: &embeddings.Config{ - BackendType: "placeholder", - Dimension: 384, + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, }, } - integration, err := NewIntegration(ctx, config, mcpServer, mockClient) + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) require.NoError(t, err) err = integration.Close() diff --git a/pkg/vmcp/server/optimizer_test.go b/pkg/vmcp/server/optimizer_test.go new file mode 100644 index 0000000000..0d8cba1ad5 --- /dev/null +++ b/pkg/vmcp/server/optimizer_test.go @@ -0,0 +1,350 @@ +package server + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + discoveryMocks "github.com/stacklok/toolhive/pkg/vmcp/discovery/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/mocks" + "github.com/stacklok/toolhive/pkg/vmcp/router" +) + +// TestNew_OptimizerEnabled tests server creation with optimizer enabled +func TestNew_OptimizerEnabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &OptimizerConfig{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + HybridSearchRatio: 0.7, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() + + // Verify optimizer integration was created + // We can't directly access optimizerIntegration, but we can verify server was created successfully +} + +// TestNew_OptimizerDisabled tests server creation with optimizer disabled +func TestNew_OptimizerDisabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &OptimizerConfig{ + Enabled: false, // Disabled + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerConfigNil tests server creation with nil optimizer config +func TestNew_OptimizerConfigNil(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: nil, // Nil config + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerIngestionError tests error handling during optimizer ingestion +func TestNew_OptimizerIngestionError(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + // Return error when listing capabilities + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(nil, assert.AnError). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &OptimizerConfig{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // Should not fail even if ingestion fails + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err, "Server should be created even if optimizer ingestion fails") + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestNew_OptimizerHybridRatio tests hybrid ratio configuration +func TestNew_OptimizerHybridRatio(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &OptimizerConfig{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + HybridSearchRatio: 0.5, // Custom ratio + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + defer func() { _ = srv.Stop(context.Background()) }() +} + +// TestServer_Stop_OptimizerCleanup tests optimizer cleanup on server stop +func TestServer_Stop_OptimizerCleanup(t *testing.T) { + t.Parallel() + ctx := context.Background() + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockBackendClient := mocks.NewMockBackendClient(ctrl) + mockBackendClient.EXPECT(). + ListCapabilities(gomock.Any(), gomock.Any()). + Return(&vmcp.CapabilityList{}, nil). + AnyTimes() + + mockDiscoveryMgr := discoveryMocks.NewMockManager(ctrl) + mockDiscoveryMgr.EXPECT(). + Discover(gomock.Any(), gomock.Any()). + Return(&aggregator.AggregatedCapabilities{}, nil). + AnyTimes() + mockDiscoveryMgr.EXPECT().Stop().AnyTimes() + + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + cfg := &Config{ + Name: "test-server", + Version: "1.0.0", + Host: "127.0.0.1", + Port: 0, + SessionTTL: 5 * time.Minute, + OptimizerConfig: &OptimizerConfig{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + }, + } + + rt := router.NewDefaultRouter() + backends := []vmcp.Backend{} + + srv, err := New(ctx, cfg, rt, mockBackendClient, mockDiscoveryMgr, vmcp.NewImmutableRegistry(backends), nil) + require.NoError(t, err) + require.NotNil(t, srv) + + // Stop should clean up optimizer + err = srv.Stop(context.Background()) + require.NoError(t, err) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index e0fc6235e4..7ec8b2bab5 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -370,7 +370,15 @@ func New( if cfg.HealthMonitorConfig != nil { // Get initial backends list from registry for health monitoring setup initialBackends := backendRegistry.List(ctx) - healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig) + + // Construct server's own URL for self-check detection + // Use http:// as default scheme (most common for local development) + var selfURL string + if cfg.Host != "" && cfg.Port > 0 { + selfURL = fmt.Sprintf("http://%s:%d", cfg.Host, cfg.Port) + } + + healthMon, err = health.NewMonitor(backendClient, initialBackends, *cfg.HealthMonitorConfig, selfURL) if err != nil { return nil, fmt.Errorf("failed to create health monitor: %w", err) } @@ -385,40 +393,44 @@ func New( // Initialize optimizer integration if enabled var optimizerInteg OptimizerIntegration - if cfg.OptimizerConfig != nil && cfg.OptimizerConfig.Enabled { - logger.Infow("Initializing optimizer integration (chromem-go)", - "persist_path", cfg.OptimizerConfig.PersistPath, - "embedding_backend", cfg.OptimizerConfig.EmbeddingBackend) - - // Convert server config to optimizer config - hybridRatio := 0.7 // Default - if cfg.OptimizerConfig.HybridSearchRatio != 0 { - hybridRatio = cfg.OptimizerConfig.HybridSearchRatio - } - optimizerCfg := &optimizer.Config{ - Enabled: cfg.OptimizerConfig.Enabled, - PersistPath: cfg.OptimizerConfig.PersistPath, - FTSDBPath: cfg.OptimizerConfig.FTSDBPath, - HybridSearchRatio: hybridRatio, - EmbeddingConfig: &embeddings.Config{ - BackendType: cfg.OptimizerConfig.EmbeddingBackend, - BaseURL: cfg.OptimizerConfig.EmbeddingURL, - Model: cfg.OptimizerConfig.EmbeddingModel, - Dimension: cfg.OptimizerConfig.EmbeddingDimension, - }, - } + if cfg.OptimizerConfig != nil { + if cfg.OptimizerConfig.Enabled { + logger.Infow("Initializing optimizer integration (chromem-go)", + "persist_path", cfg.OptimizerConfig.PersistPath, + "embedding_backend", cfg.OptimizerConfig.EmbeddingBackend) + + // Convert server config to optimizer config + hybridRatio := 0.7 // Default + if cfg.OptimizerConfig.HybridSearchRatio != 0 { + hybridRatio = cfg.OptimizerConfig.HybridSearchRatio + } + optimizerCfg := &optimizer.Config{ + Enabled: cfg.OptimizerConfig.Enabled, + PersistPath: cfg.OptimizerConfig.PersistPath, + FTSDBPath: cfg.OptimizerConfig.FTSDBPath, + HybridSearchRatio: hybridRatio, + EmbeddingConfig: &embeddings.Config{ + BackendType: cfg.OptimizerConfig.EmbeddingBackend, + BaseURL: cfg.OptimizerConfig.EmbeddingURL, + Model: cfg.OptimizerConfig.EmbeddingModel, + Dimension: cfg.OptimizerConfig.EmbeddingDimension, + }, + } - optimizerInteg, err = optimizer.NewIntegration(ctx, optimizerCfg, mcpServer, backendClient) - if err != nil { - return nil, fmt.Errorf("failed to initialize optimizer: %w", err) - } - logger.Info("Optimizer integration initialized successfully") + optimizerInteg, err = optimizer.NewIntegration(ctx, optimizerCfg, mcpServer, backendClient, sessionManager) + if err != nil { + return nil, fmt.Errorf("failed to initialize optimizer: %w", err) + } + logger.Info("Optimizer integration initialized successfully") - // Ingest discovered backends at startup (populate optimizer database) - initialBackends := backendRegistry.List(ctx) - if err := optimizerInteg.IngestInitialBackends(ctx, initialBackends); err != nil { - logger.Warnf("Failed to ingest initial backends: %v", err) - // Don't fail server startup - optimizer can still work with incremental ingestion + // Ingest discovered backends at startup (populate optimizer database) + initialBackends := backendRegistry.List(ctx) + if err := optimizerInteg.IngestInitialBackends(ctx, initialBackends); err != nil { + logger.Warnf("Failed to ingest initial backends: %v", err) + // Don't fail server startup - optimizer can still work with incremental ingestion + } + } else { + logger.Info("Optimizer configuration present but disabled (enabled=false), skipping initialization") } } @@ -512,23 +524,59 @@ func New( "resource_count", len(caps.RoutingTable.Resources), "prompt_count", len(caps.RoutingTable.Prompts)) - // Inject capabilities into SDK session - if err := srv.injectCapabilities(sessionID, caps); err != nil { - logger.Errorw("failed to inject session capabilities", - "error", err, - "session_id", sessionID) - return - } + // When optimizer is enabled, we should NOT inject backend tools directly. + // Instead, only optimizer tools (optim.find_tool, optim.call_tool) will be exposed. + // Backend tools are still discovered and stored for optimizer ingestion, + // but not exposed directly to clients. + if srv.optimizerIntegration == nil { + // Inject capabilities into SDK session (only when optimizer is disabled) + if err := srv.injectCapabilities(sessionID, caps); err != nil { + logger.Errorw("failed to inject session capabilities", + "error", err, + "session_id", sessionID) + return + } - logger.Infow("session capabilities injected", - "session_id", sessionID, - "tool_count", len(caps.Tools), - "resource_count", len(caps.Resources)) + logger.Infow("session capabilities injected", + "session_id", sessionID, + "tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) + } else { + // Optimizer is enabled - register optimizer tools FIRST so they're available immediately + // Backend tools will be accessible via optim.find_tool and optim.call_tool + if err := srv.optimizerIntegration.RegisterTools(ctx, session); err != nil { + logger.Errorw("failed to register optimizer tools", + "error", err, + "session_id", sessionID) + // Don't fail session initialization - continue without optimizer tools + } else { + logger.Infow("optimizer tools registered", + "session_id", sessionID) + } - // Generate embeddings and register optimizer tools if enabled - if srv.optimizerIntegration != nil { - logger.Debugw("Generating embeddings for optimizer", "session_id", sessionID) + // Inject resources (but not backend tools) + if len(caps.Resources) > 0 { + sdkResources := srv.capabilityAdapter.ToSDKResources(caps.Resources) + if err := srv.mcpServer.AddSessionResources(sessionID, sdkResources...); err != nil { + logger.Errorw("failed to add session resources", + "error", err, + "session_id", sessionID) + return + } + logger.Debugw("added session resources (optimizer mode)", + "session_id", sessionID, + "count", len(sdkResources)) + } + logger.Infow("optimizer mode: backend tools not exposed directly", + "session_id", sessionID, + "backend_tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) + } + // Generate embeddings for optimizer if enabled + // This happens after tools are registered so tools are available immediately + if srv.optimizerIntegration != nil { + logger.Debugw("Calling OnRegisterSession for optimizer", "session_id", sessionID) // Generate embeddings for all tools in this session if err := srv.optimizerIntegration.OnRegisterSession(ctx, session, caps); err != nil { logger.Errorw("failed to generate embeddings for optimizer", @@ -536,16 +584,7 @@ func New( "session_id", sessionID) // Don't fail session initialization - continue without optimizer } else { - // Register optimizer tools (optim.find_tool, optim.call_tool) - if err := srv.optimizerIntegration.RegisterTools(ctx, session); err != nil { - logger.Errorw("failed to register optimizer tools", - "error", err, - "session_id", sessionID) - // Don't fail session initialization - continue without optimizer tools - } else { - logger.Infow("optimizer tools registered", - "session_id", sessionID) - } + logger.Debugw("OnRegisterSession completed successfully", "session_id", sessionID) } } }) diff --git a/scripts/README.md b/scripts/README.md index 09a382f6b0..fa19fe399d 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -81,7 +81,40 @@ Then open any `.db` file in VSCode to browse tables visually. ## Testing Scripts -### Optimizer Tests +### Optimizer Tool Finding Tests + +These scripts test the `optim.find_tool` functionality in different scenarios: + +#### Test via vMCP Server Connection +```bash +# Test optim.find_tool through a running vMCP server +go run scripts/test-vmcp-find-tool/main.go "read pull requests from GitHub" [server_url] + +# Default server URL: http://localhost:4483/mcp +# Example: +go run scripts/test-vmcp-find-tool/main.go "search the web" http://localhost:4483/mcp +``` +Connects to a running vMCP server and calls `optim.find_tool` via the MCP protocol. Useful for integration testing with a live server. + +#### Call Optimizer Tool Directly +```bash +# Call optim.find_tool via MCP client +go run scripts/call-optim-find-tool/main.go [tool_keywords] [limit] [server_url] + +# Examples: +go run scripts/call-optim-find-tool/main.go "search the web" "web search" 20 +go run scripts/call-optim-find-tool/main.go "read files" "" 10 http://localhost:4483/mcp +``` +A more flexible client for calling `optim.find_tool` with various parameters. Useful for manual testing and debugging. + +#### Test Optimizer Handler Directly +```bash +# Test the optimizer handler directly (unit test style) +go run scripts/test-optim-find-tool/main.go "read pull requests from GitHub" +``` +Tests the optimizer's `find_tool` handler directly without requiring a full vMCP server. Creates a mock environment with test tools and embeddings. Useful for development and debugging the optimizer logic. + +### Other Optimizer Tests ```bash # Test with sqlite-vec extension ./scripts/test-optimizer-with-sqlite-vec.sh diff --git a/scripts/call-optim-find-tool/main.go b/scripts/call-optim-find-tool/main.go new file mode 100644 index 0000000000..3df36a3e86 --- /dev/null +++ b/scripts/call-optim-find-tool/main.go @@ -0,0 +1,137 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run main.go [tool_keywords] [limit] [server_url]") + fmt.Println("Example: go run main.go 'search the web' 'web search' 20") + fmt.Println("Default server URL: http://localhost:4483/mcp") + os.Exit(1) + } + + toolDescription := os.Args[1] + toolKeywords := "" + if len(os.Args) >= 3 { + toolKeywords = os.Args[2] + } + limit := 20 + if len(os.Args) >= 4 { + if l, err := fmt.Sscanf(os.Args[3], "%d", &limit); err != nil || l != 1 { + fmt.Printf("Invalid limit: %s, using default 20\n", os.Args[3]) + limit = 20 + } + } + serverURL := "http://localhost:4483/mcp" + if len(os.Args) >= 5 { + serverURL = os.Args[4] + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Create streamable-http client to connect to vmcp server + mcpClient, err := client.NewStreamableHttpClient( + serverURL, + transport.WithHTTPTimeout(30*time.Second), + transport.WithContinuousListening(), + ) + if err != nil { + fmt.Printf("❌ Failed to create MCP client: %v\n", err) + os.Exit(1) + } + defer func() { + if err := mcpClient.Close(); err != nil { + fmt.Printf("⚠️ Error closing client: %v\n", err) + } + }() + + // Start the client connection + if err := mcpClient.Start(ctx); err != nil { + fmt.Printf("❌ Failed to start client connection: %v\n", err) + os.Exit(1) + } + + // Initialize the client + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "optim-find-tool-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + if err != nil { + fmt.Printf("❌ Failed to initialize client: %v\n", err) + os.Exit(1) + } + fmt.Printf("✅ Connected to: %s %s\n", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + + // Call optim.find_tool + args := map[string]any{ + "tool_description": toolDescription, + "limit": limit, + } + if toolKeywords != "" { + args["tool_keywords"] = toolKeywords + } + + callResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: args, + }, + }) + if err != nil { + fmt.Printf("❌ Failed to call optim.find_tool: %v\n", err) + os.Exit(1) + } + + if callResult.IsError { + fmt.Printf("❌ Tool call returned an error\n") + if len(callResult.Content) > 0 { + if textContent, ok := mcp.AsTextContent(callResult.Content[0]); ok { + fmt.Printf("Error: %s\n", textContent.Text) + } + } + os.Exit(1) + } + + // Parse and display the result + if len(callResult.Content) > 0 { + if textContent, ok := mcp.AsTextContent(callResult.Content[0]); ok { + // Try to parse as JSON for pretty printing + var resultData map[string]any + if err := json.Unmarshal([]byte(textContent.Text), &resultData); err == nil { + // Pretty print JSON + prettyJSON, err := json.MarshalIndent(resultData, "", " ") + if err == nil { + fmt.Println(string(prettyJSON)) + } else { + fmt.Println(textContent.Text) + } + } else { + fmt.Println(textContent.Text) + } + } else { + fmt.Printf("%+v\n", callResult.Content) + } + } else { + fmt.Println("(No content returned)") + } +} diff --git a/scripts/inspect-chromem/inspect-chromem.go b/scripts/inspect-chromem/inspect-chromem.go index 672741b5ae..14b5c5e4a0 100644 --- a/scripts/inspect-chromem/inspect-chromem.go +++ b/scripts/inspect-chromem/inspect-chromem.go @@ -35,9 +35,9 @@ func main() { fmt.Println(" - backend_tools") fmt.Println() - // Create a dummy embedding function (we're just inspecting, not querying) + // Create an embedding function for collection access (we're just inspecting, not querying) dummyEmbedding := func(ctx context.Context, text string) ([]float32, error) { - return make([]float32, 384), nil // Placeholder + return make([]float32, 384), nil } // Inspect backend_servers collection diff --git a/scripts/test-optim-find-tool/main.go b/scripts/test-optim-find-tool/main.go new file mode 100644 index 0000000000..e61fc8c9c2 --- /dev/null +++ b/scripts/test-optim-find-tool/main.go @@ -0,0 +1,246 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/optimizer/embeddings" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run main.go ") + fmt.Println("Example: go run main.go 'read pull requests from GitHub'") + os.Exit(1) + } + + query := os.Args[1] + ctx := context.Background() + tmpDir := filepath.Join(os.TempDir(), "optimizer-test") + os.MkdirAll(tmpDir, 0755) + + fmt.Printf("🔍 Testing optim.find_tool with query: %s\n\n", query) + + // Create MCP server + mcpServer := server.NewMCPServer("test-server", "1.0") + + // Create mock backend client + mockClient := &mockBackendClient{} + + // Configure optimizer + optimizerConfig := &optimizer.Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := optimizer.NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + if err != nil { + fmt.Printf("❌ Failed to create optimizer integration: %v\n", err) + os.Exit(1) + } + defer func() { _ = integration.Close() }() + + fmt.Println("✅ Optimizer integration created") + + // Ingest some test tools + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + if err != nil { + fmt.Printf("⚠️ Failed to ingest initial backends: %v (continuing...)\n", err) + } + + // Create a test session + sessionID := "test-session-123" + testSession := &mockSession{sessionID: sessionID} + + // Create capabilities with GitHub tools + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Read details of a pull request from GitHub", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Read details of an issue from GitHub", + BackendID: "github", + }, + { + Name: "github_pull_request_list", + Description: "List pull requests in a GitHub repository", + BackendID: "github", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "github_pull_request_read": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + "github_issue_read": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + "github_pull_request_list": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + // Register session with MCP server first (needed for RegisterTools) + err = mcpServer.RegisterSession(ctx, testSession) + if err != nil { + fmt.Printf("⚠️ Failed to register session: %v\n", err) + } + + // Generate embeddings for session + err = integration.OnRegisterSession(ctx, testSession, capabilities) + if err != nil { + fmt.Printf("❌ Failed to generate embeddings: %v\n", err) + os.Exit(1) + } + fmt.Println("✅ Embeddings generated for session") + + // Skip RegisterTools since we're calling the handler directly + // RegisterTools requires per-session tool support which the mock doesn't have + // err = integration.RegisterTools(ctx, testSession) + // if err != nil { + // fmt.Printf("⚠️ Failed to register optimizer tools: %v (skipping, calling handler directly)\n", err) + // } + fmt.Println("⏭️ Skipping tool registration (testing handler directly)") + + // Now try to call optim.find_tool directly via the handler + fmt.Printf("\n🔍 Calling optim.find_tool handler directly...\n\n") + + // Create a context with capabilities (needed for the handler) + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Create the tool call request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "github pull request", + "limit": 10, + }, + }, + } + + // Call the handler directly using the exported test method + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + if err != nil { + fmt.Printf("❌ Failed to call optim.find_tool: %v\n", err) + os.Exit(1) + } + + fmt.Println("\n✅ Successfully called optim.find_tool!") + fmt.Println("\n📊 Results:") + + // Print the result - CallToolResult has Content field which is a slice + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + fmt.Printf("Error marshaling result: %v\n", err) + fmt.Printf("Raw result: %+v\n", result) + } else { + fmt.Println(string(resultJSON)) + } +} + +type mockBackendClient struct{} + +func (m *mockBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + return &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Read details of a pull request from GitHub", + }, + { + Name: "github_issue_read", + Description: "Read details of an issue from GitHub", + }, + { + Name: "github_pull_request_list", + Description: "List pull requests in a GitHub repository", + }, + }, + }, nil +} + +func (m *mockBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (map[string]any, error) { + return nil, nil +} + +func (m *mockBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (string, error) { + return "", nil +} + +func (m *mockBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) ([]byte, error) { + return nil, nil +} + +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +func (m *mockSession) Send(_ interface{}) error { + return nil +} + +func (m *mockSession) Close() error { + return nil +} + +func (m *mockSession) Initialize() {} + +func (m *mockSession) Initialized() bool { + return true +} + +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} diff --git a/scripts/test-vmcp-find-tool/main.go b/scripts/test-vmcp-find-tool/main.go new file mode 100644 index 0000000000..71861d2508 --- /dev/null +++ b/scripts/test-vmcp-find-tool/main.go @@ -0,0 +1,158 @@ +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +func main() { + if len(os.Args) < 2 { + fmt.Println("Usage: go run main.go [server_url]") + fmt.Println("Example: go run main.go 'read pull requests from GitHub'") + fmt.Println("Default server URL: http://localhost:4483/mcp") + os.Exit(1) + } + + query := os.Args[1] + serverURL := "http://localhost:4483/mcp" + if len(os.Args) >= 3 { + serverURL = os.Args[2] + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + fmt.Printf("🔍 Testing optim.find_tool via vmcp server\n") + fmt.Printf(" Server: %s\n", serverURL) + fmt.Printf(" Query: %s\n\n", query) + + // Create streamable-http client to connect to vmcp server + mcpClient, err := client.NewStreamableHttpClient( + serverURL, + transport.WithHTTPTimeout(30*time.Second), + transport.WithContinuousListening(), + ) + if err != nil { + fmt.Printf("❌ Failed to create MCP client: %v\n", err) + os.Exit(1) + } + defer func() { + if err := mcpClient.Close(); err != nil { + fmt.Printf("⚠️ Error closing client: %v\n", err) + } + }() + + // Start the client connection + if err := mcpClient.Start(ctx); err != nil { + fmt.Printf("❌ Failed to start client connection: %v\n", err) + os.Exit(1) + } + fmt.Println("✅ Connected to vmcp server") + + // Initialize the client + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-vmcp-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + if err != nil { + fmt.Printf("❌ Failed to initialize client: %v\n", err) + os.Exit(1) + } + fmt.Printf("✅ Initialized - Server: %s %s\n\n", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + + // List available tools to see if optim.find_tool is available + fmt.Println("📋 Listing available tools...") + toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + fmt.Printf("❌ Failed to list tools: %v\n", err) + os.Exit(1) + } + + fmt.Printf("Found %d tools:\n", len(toolsResult.Tools)) + hasFindTool := false + for _, tool := range toolsResult.Tools { + fmt.Printf(" - %s: %s\n", tool.Name, tool.Description) + if tool.Name == "optim.find_tool" { + hasFindTool = true + } + } + fmt.Println() + + if !hasFindTool { + fmt.Println("⚠️ Warning: optim.find_tool not found in available tools") + fmt.Println(" The optimizer may not be enabled on this vmcp server") + fmt.Println(" Continuing anyway...\n") + } + + // Call optim.find_tool + fmt.Printf("🔍 Calling optim.find_tool with query: %s\n\n", query) + + callResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "pull request", + "limit": 20, + }, + }, + }) + if err != nil { + fmt.Printf("❌ Failed to call optim.find_tool: %v\n", err) + os.Exit(1) + } + + if callResult.IsError { + fmt.Printf("❌ Tool call returned an error\n") + if len(callResult.Content) > 0 { + if textContent, ok := mcp.AsTextContent(callResult.Content[0]); ok { + fmt.Printf("Error: %s\n", textContent.Text) + } + } + os.Exit(1) + } + + fmt.Println("✅ Successfully called optim.find_tool!") + fmt.Println("\n📊 Results:") + + // Parse and display the result + if len(callResult.Content) > 0 { + if textContent, ok := mcp.AsTextContent(callResult.Content[0]); ok { + // Try to parse as JSON for pretty printing + var resultData map[string]any + if err := json.Unmarshal([]byte(textContent.Text), &resultData); err == nil { + // Pretty print JSON + prettyJSON, err := json.MarshalIndent(resultData, "", " ") + if err == nil { + fmt.Println(string(prettyJSON)) + } else { + fmt.Println(textContent.Text) + } + } else { + // Not JSON, print as-is + fmt.Println(textContent.Text) + } + } else { + // Not text content, print raw + fmt.Printf("%+v\n", callResult.Content) + } + } else { + fmt.Println("(No content returned)") + } +} From 949b284b744d6ed15bf13f5900e27ea5bc30eff7 Mon Sep 17 00:00:00 2001 From: Nigel Brown Date: Mon, 19 Jan 2026 12:49:22 +0000 Subject: [PATCH 04/16] fix: Resolve tool names in optim.find_tool to match routing table (#3337) * fix: Resolve tool names in optim.find_tool to match routing table --- pkg/vmcp/discovery/middleware_test.go | 13 ++++- pkg/vmcp/optimizer/optimizer.go | 71 +++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 6 deletions(-) diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go index 7cbaad0ab1..4d82eb0dca 100644 --- a/pkg/vmcp/discovery/middleware_test.go +++ b/pkg/vmcp/discovery/middleware_test.go @@ -301,8 +301,19 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) { }, } + // Use Do to capture and verify backends separately, since order may vary mockMgr.EXPECT(). - Discover(gomock.Any(), backends). + Discover(gomock.Any(), gomock.Any()). + Do(func(_ context.Context, actualBackends []vmcp.Backend) { + // Verify that we got the expected backends regardless of order + assert.Len(t, actualBackends, 2) + backendIDs := make(map[string]bool) + for _, b := range actualBackends { + backendIDs[b.ID] = true + } + assert.True(t, backendIDs["backend1"], "backend1 should be present") + assert.True(t, backendIDs["backend2"], "backend2 should be present") + }). Return(expectedCaps, nil) // Create handler that inspects context in detail diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index 19553ea2e1..d03c294fa2 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -234,8 +234,60 @@ func extractFindToolParams(args map[string]any) (toolDescription, toolKeywords s return toolDescription, toolKeywords, limit, nil } -// convertSearchResultsToResponse converts database search results to the response format -func convertSearchResultsToResponse(results []*models.BackendToolWithMetadata) ([]map[string]any, int) { +// resolveToolName looks up the resolved name for a tool in the routing table. +// Returns the resolved name if found, otherwise returns the original name. +// +// The routing table maps resolved names (after conflict resolution) to BackendTarget. +// Each BackendTarget contains: +// - WorkloadID: the backend ID +// - OriginalCapabilityName: the original tool name (empty if not renamed) +// +// We need to find the resolved name by matching backend ID and original name. +func resolveToolName(routingTable *vmcp.RoutingTable, backendID string, originalName string) string { + if routingTable == nil || routingTable.Tools == nil { + return originalName + } + + // Search through routing table to find the resolved name + // Match by backend ID and original capability name + for resolvedName, target := range routingTable.Tools { + // Case 1: Tool was renamed (OriginalCapabilityName is set) + // Match by backend ID and original name + if target.WorkloadID == backendID && target.OriginalCapabilityName == originalName { + logger.Debugw("Resolved tool name (renamed)", + "backend_id", backendID, + "original_name", originalName, + "resolved_name", resolvedName) + return resolvedName + } + + // Case 2: Tool was not renamed (OriginalCapabilityName is empty) + // Match by backend ID and resolved name equals original name + if target.WorkloadID == backendID && target.OriginalCapabilityName == "" && resolvedName == originalName { + logger.Debugw("Resolved tool name (not renamed)", + "backend_id", backendID, + "original_name", originalName, + "resolved_name", resolvedName) + return resolvedName + } + } + + // If not found, return original name (fallback for tools not in routing table) + // This can happen if: + // - Tool was just ingested but routing table hasn't been updated yet + // - Tool belongs to a backend that's not currently registered + logger.Debugw("Tool name not found in routing table, using original name", + "backend_id", backendID, + "original_name", originalName) + return originalName +} + +// convertSearchResultsToResponse converts database search results to the response format. +// It resolves tool names using the routing table to ensure returned names match routing table keys. +func convertSearchResultsToResponse( + results []*models.BackendToolWithMetadata, + routingTable *vmcp.RoutingTable, +) ([]map[string]any, int) { responseTools := make([]map[string]any, 0, len(results)) totalReturnedTokens := 0 @@ -258,8 +310,11 @@ func convertSearchResultsToResponse(results []*models.BackendToolWithMetadata) ( description = *result.Description } + // Resolve tool name using routing table to ensure it matches routing table keys + resolvedName := resolveToolName(routingTable, result.MCPServerID, result.ToolName) + tool := map[string]any{ - "name": result.ToolName, + "name": resolvedName, "description": description, "input_schema": inputSchema, "backend_id": result.MCPServerID, @@ -321,8 +376,14 @@ func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", err2)), nil } - // Convert results to response format - responseTools, totalReturnedTokens := convertSearchResultsToResponse(results) + // Get routing table from context to resolve tool names + var routingTable *vmcp.RoutingTable + if capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx); ok && capabilities != nil { + routingTable = capabilities.RoutingTable + } + + // Convert results to response format, resolving tool names to match routing table + responseTools, totalReturnedTokens := convertSearchResultsToResponse(results, routingTable) // Calculate token metrics baselineTokens := o.ingestionService.GetTotalToolTokens(ctx) From eb22a40558f324149e839dd8c2633c521b392d26 Mon Sep 17 00:00:00 2001 From: Nigel Brown Date: Tue, 20 Jan 2026 10:13:45 +0000 Subject: [PATCH 05/16] Add token metrics and observability to optimizer integration (#3347) * feat: Add token metrics and observability to optimizer integration --- .gitignore | 3 +- examples/vmcp-config-optimizer.yaml | 13 + pkg/optimizer/ingestion/service.go | 119 ++++++++- pkg/optimizer/ingestion/service_test.go | 73 ++++++ pkg/vmcp/optimizer/optimizer.go | 114 +++++++- .../optimizer/optimizer_integration_test.go | 248 ++++++++++++++++++ pkg/vmcp/server/server.go | 15 +- 7 files changed, 566 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index 7f5be4bf18..f0840c001e 100644 --- a/.gitignore +++ b/.gitignore @@ -42,4 +42,5 @@ cmd/thv-operator/.task/checksum/crdref-gen # Test coverage coverage* -crd-helm-wrapper \ No newline at end of file +crd-helm-wrapper +cmd/vmcp/__debug_bin* diff --git a/examples/vmcp-config-optimizer.yaml b/examples/vmcp-config-optimizer.yaml index 7687dabb7d..4770caf355 100644 --- a/examples/vmcp-config-optimizer.yaml +++ b/examples/vmcp-config-optimizer.yaml @@ -95,6 +95,19 @@ optimizer: # embeddingService: embedding-service-name # (vMCP will resolve the service DNS name) +# ============================================================================= +# TELEMETRY CONFIGURATION (for Jaeger tracing) +# ============================================================================= +# Configure OpenTelemetry to send traces to Jaeger +telemetry: + endpoint: "localhost:4318" # OTLP HTTP endpoint (Jaeger collector) - no http:// prefix needed with insecure: true + serviceName: "vmcp-optimizer" + serviceVersion: "1.0.0" # Optional: service version + tracingEnabled: true + metricsEnabled: false # Set to true if you want metrics too + samplingRate: "1.0" # 100% sampling for development (use lower in production) + insecure: true # Use HTTP instead of HTTPS + # ============================================================================= # USAGE # ============================================================================= diff --git a/pkg/optimizer/ingestion/service.go b/pkg/optimizer/ingestion/service.go index 9b63e01289..66e46f57d6 100644 --- a/pkg/optimizer/ingestion/service.go +++ b/pkg/optimizer/ingestion/service.go @@ -4,10 +4,15 @@ import ( "context" "encoding/json" "fmt" + "sync" "time" "github.com/google/uuid" "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/optimizer/db" @@ -47,6 +52,11 @@ type Service struct { tokenCounter *tokens.Counter backendServerOps *db.BackendServerOps backendToolOps *db.BackendToolOps + tracer trace.Tracer + + // Embedding time tracking + embeddingTimeMu sync.Mutex + totalEmbeddingTime time.Duration } // NewService creates a new ingestion service @@ -80,27 +90,58 @@ func NewService(config *Config) (*Service, error) { // Initialize token counter tokenCounter := tokens.NewCounter() - // Create chromem-go embeddingFunc from our embedding manager - embeddingFunc := func(_ context.Context, text string) ([]float32, error) { + // Initialize tracer + tracer := otel.Tracer("github.com/stacklok/toolhive/pkg/optimizer/ingestion") + + svc := &Service{ + config: config, + database: database, + embeddingManager: embeddingManager, + tokenCounter: tokenCounter, + tracer: tracer, + totalEmbeddingTime: 0, + } + + // Create chromem-go embeddingFunc from our embedding manager with tracing + embeddingFunc := func(ctx context.Context, text string) ([]float32, error) { + // Create a span for embedding calculation + _, span := svc.tracer.Start(ctx, "optimizer.ingestion.calculate_embedding", + trace.WithAttributes( + attribute.String("operation", "embedding_calculation"), + )) + defer span.End() + + start := time.Now() + // Our manager takes a slice, so wrap the single text embeddingsResult, err := embeddingManager.GenerateEmbedding([]string{text}) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, err } if len(embeddingsResult) == 0 { - return nil, fmt.Errorf("no embeddings generated") + err := fmt.Errorf("no embeddings generated") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err } + + // Track embedding time + duration := time.Since(start) + svc.embeddingTimeMu.Lock() + svc.totalEmbeddingTime += duration + svc.embeddingTimeMu.Unlock() + + span.SetAttributes( + attribute.Int64("embedding.duration_ms", duration.Milliseconds()), + ) + return embeddingsResult[0], nil } - svc := &Service{ - config: config, - database: database, - embeddingManager: embeddingManager, - tokenCounter: tokenCounter, - backendServerOps: db.NewBackendServerOps(database, embeddingFunc), - backendToolOps: db.NewBackendToolOps(database, embeddingFunc), - } + svc.backendServerOps = db.NewBackendServerOps(database, embeddingFunc) + svc.backendToolOps = db.NewBackendToolOps(database, embeddingFunc) logger.Info("Ingestion service initialized for event-driven ingestion (chromem-go)") return svc, nil @@ -129,6 +170,16 @@ func (s *Service) IngestServer( description *string, tools []mcp.Tool, ) error { + // Create a span for the entire ingestion operation + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.ingest_server", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + + start := time.Now() logger.Infof("Ingesting server: %s (%d tools) [serverID=%s]", serverName, len(tools), serverID) // Create backend server record (simplified - vMCP manages lifecycle) @@ -144,6 +195,8 @@ func (s *Service) IngestServer( // Create or update server (chromem-go handles embeddings) if err := s.backendServerOps.Update(ctx, backendServer); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return fmt.Errorf("failed to create/update server %s: %w", serverName, err) } logger.Debugf("Created/updated server: %s", serverName) @@ -151,18 +204,42 @@ func (s *Service) IngestServer( // Sync tools for this server toolCount, err := s.syncBackendTools(ctx, serverID, serverName, tools) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return fmt.Errorf("failed to sync tools for %s: %w", serverName, err) } - logger.Infof("Successfully ingested server %s with %d tools", serverName, toolCount) + duration := time.Since(start) + span.SetAttributes( + attribute.Int64("ingestion.duration_ms", duration.Milliseconds()), + attribute.Int("tools.ingested", toolCount), + ) + + logger.Infow("Successfully ingested server", + "server_name", serverName, + "server_id", serverID, + "tools_count", toolCount, + "duration_ms", duration.Milliseconds()) return nil } // syncBackendTools synchronizes tools for a backend server func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) { + // Create a span for tool synchronization + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.sync_backend_tools", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + logger.Debugf("syncBackendTools: server=%s, serverID=%s, tool_count=%d", serverName, serverID, len(tools)) + // Delete existing tools if err := s.backendToolOps.DeleteByServer(ctx, serverID); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return 0, fmt.Errorf("failed to delete existing tools: %w", err) } @@ -178,6 +255,8 @@ func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverN // Convert InputSchema to JSON schemaJSON, err := json.Marshal(tool.InputSchema) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return 0, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) } @@ -193,6 +272,8 @@ func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverN } if err := s.backendToolOps.Create(ctx, backendTool, serverName); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return 0, fmt.Errorf("failed to create tool %s: %w", tool.Name, err) } } @@ -228,6 +309,20 @@ func (s *Service) GetTotalToolTokens(ctx context.Context) int { return 0 } +// GetTotalEmbeddingTime returns the total time spent calculating embeddings +func (s *Service) GetTotalEmbeddingTime() time.Duration { + s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + return s.totalEmbeddingTime +} + +// ResetEmbeddingTime resets the total embedding time counter +func (s *Service) ResetEmbeddingTime() { + s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + s.totalEmbeddingTime = 0 +} + // Close releases resources func (s *Service) Close() error { var errs []error diff --git a/pkg/optimizer/ingestion/service_test.go b/pkg/optimizer/ingestion/service_test.go index acc5b18754..5777bf3049 100644 --- a/pkg/optimizer/ingestion/service_test.go +++ b/pkg/optimizer/ingestion/service_test.go @@ -5,6 +5,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" @@ -108,6 +109,78 @@ func TestServiceCreationAndIngestion(t *testing.T) { require.True(t, toolNamesFound["search_web"], "search_web should be in results") } +// TestService_EmbeddingTimeTracking tests that embedding time is tracked correctly +func TestService_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingConfig: &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + }, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Initially, embedding time should be 0 + initialTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), initialTime, "Initial embedding time should be 0") + + // Create test tools + tools := []mcp.Tool{ + { + Name: "test_tool_1", + Description: "First test tool for embedding", + }, + { + Name: "test_tool_2", + Description: "Second test tool for embedding", + }, + } + + // Reset embedding time before ingestion + svc.ResetEmbeddingTime() + + // Ingest server with tools (this will generate embeddings) + err = svc.IngestServer(ctx, "test-server-id", "TestServer", nil, tools) + require.NoError(t, err) + + // After ingestion, embedding time should be greater than 0 + totalEmbeddingTime := svc.GetTotalEmbeddingTime() + require.Greater(t, totalEmbeddingTime, time.Duration(0), + "Total embedding time should be greater than 0 after ingestion") + + // Reset and verify it's back to 0 + svc.ResetEmbeddingTime() + resetTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), resetTime, "Embedding time should be 0 after reset") +} + // TestServiceWithOllama demonstrates using real embeddings (requires Ollama running) // This test can be enabled locally to verify Ollama integration func TestServiceWithOllama(t *testing.T) { diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index d03c294fa2..03e32ce5d3 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -16,9 +16,14 @@ import ( "encoding/json" "fmt" "sync" + "time" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" "github.com/stacklok/toolhive/pkg/logger" "github.com/stacklok/toolhive/pkg/optimizer/db" @@ -60,6 +65,7 @@ type OptimizerIntegration struct { backendClient vmcp.BackendClient // For querying backends at startup sessionManager *transportsession.Manager processedSessions sync.Map // Track sessions that have already been processed + tracer trace.Tracer } // NewIntegration creates a new optimizer integration. @@ -94,6 +100,7 @@ func NewIntegration( mcpServer: mcpServer, backendClient: backendClient, sessionManager: sessionManager, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/optimizer"), }, nil } @@ -400,6 +407,9 @@ func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp "savings_percentage": savingsPercentage, } + // Record OpenTelemetry metrics for token savings + o.recordTokenMetrics(ctx, baselineTokens, totalReturnedTokens, tokensSaved, savingsPercentage) + // Build response response := map[string]any{ "tools": responseTools, @@ -423,6 +433,72 @@ func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp } } +// recordTokenMetrics records OpenTelemetry metrics for token savings +func (*OptimizerIntegration) recordTokenMetrics( + ctx context.Context, + baselineTokens int, + returnedTokens int, + tokensSaved int, + savingsPercentage float64, +) { + // Get meter from global OpenTelemetry provider + meter := otel.Meter("github.com/stacklok/toolhive/pkg/vmcp/optimizer") + + // Create metrics if they don't exist (they'll be cached by the meter) + baselineCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_baseline_tokens", + metric.WithDescription("Total tokens for all tools in the optimizer database (baseline)"), + ) + if err != nil { + logger.Debugw("Failed to create baseline_tokens counter", "error", err) + return + } + + returnedCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_returned_tokens", + metric.WithDescription("Total tokens for tools returned by optim.find_tool"), + ) + if err != nil { + logger.Debugw("Failed to create returned_tokens counter", "error", err) + return + } + + savedCounter, err := meter.Int64Counter( + "toolhive_vmcp_optimizer_tokens_saved", + metric.WithDescription("Number of tokens saved by filtering tools with optim.find_tool"), + ) + if err != nil { + logger.Debugw("Failed to create tokens_saved counter", "error", err) + return + } + + savingsGauge, err := meter.Float64Gauge( + "toolhive_vmcp_optimizer_savings_percentage", + metric.WithDescription("Percentage of tokens saved by filtering tools (0-100)"), + metric.WithUnit("%"), + ) + if err != nil { + logger.Debugw("Failed to create savings_percentage gauge", "error", err) + return + } + + // Record metrics with attributes + attrs := metric.WithAttributes( + attribute.String("operation", "find_tool"), + ) + + baselineCounter.Add(ctx, int64(baselineTokens), attrs) + returnedCounter.Add(ctx, int64(returnedTokens), attrs) + savedCounter.Add(ctx, int64(tokensSaved), attrs) + savingsGauge.Record(ctx, savingsPercentage, attrs) + + logger.Debugw("Token metrics recorded", + "baseline_tokens", baselineTokens, + "returned_tokens", returnedTokens, + "tokens_saved", tokensSaved, + "savings_percentage", savingsPercentage) +} + // CreateCallToolHandler creates the handler for optim.call_tool // Exported for testing purposes func (o *OptimizerIntegration) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { @@ -523,11 +599,26 @@ func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp // This should be called after backends are discovered during server initialization. func (o *OptimizerIntegration) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error { if o == nil || o.ingestionService == nil { - return nil // Optimizer disabled + // Optimizer disabled - log that embedding time is 0 + logger.Infow("Optimizer disabled, embedding time: 0ms") + return nil } + // Reset embedding time before starting ingestion + o.ingestionService.ResetEmbeddingTime() + + // Create a span for the entire ingestion process + ctx, span := o.tracer.Start(ctx, "optimizer.ingestion.ingest_initial_backends", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + )) + defer span.End() + + start := time.Now() logger.Infof("Ingesting %d discovered backends into optimizer", len(backends)) + ingestedCount := 0 + totalToolsIngested := 0 for _, backend := range backends { // Convert Backend to BackendTarget for client API target := vmcp.BackendToTarget(&backend) @@ -574,9 +665,28 @@ func (o *OptimizerIntegration) IngestInitialBackends(ctx context.Context, backen logger.Warnf("Failed to ingest backend %s: %v", backend.Name, err) continue // Log but don't fail startup } + ingestedCount++ + totalToolsIngested += len(tools) } - logger.Info("Initial backend ingestion completed") + // Get total embedding time + totalEmbeddingTime := o.ingestionService.GetTotalEmbeddingTime() + totalDuration := time.Since(start) + + span.SetAttributes( + attribute.Int64("ingestion.duration_ms", totalDuration.Milliseconds()), + attribute.Int64("embedding.duration_ms", totalEmbeddingTime.Milliseconds()), + attribute.Int("backends.ingested", ingestedCount), + attribute.Int("tools.ingested", totalToolsIngested), + ) + + logger.Infow("Initial backend ingestion completed", + "servers_ingested", ingestedCount, + "tools_ingested", totalToolsIngested, + "total_duration_ms", totalDuration.Milliseconds(), + "total_embedding_time_ms", totalEmbeddingTime.Milliseconds(), + "embedding_time_percentage", fmt.Sprintf("%.2f%%", float64(totalEmbeddingTime)/float64(totalDuration)*100)) + return nil } diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go index 2fcb912743..4742de843d 100644 --- a/pkg/vmcp/optimizer/optimizer_integration_test.go +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -2,6 +2,7 @@ package optimizer import ( "context" + "encoding/json" "path/filepath" "testing" "time" @@ -186,3 +187,250 @@ func TestOptimizerIntegration_WithVMCP(t *testing.T) { // of this integration test. The RegisterTools method is tested separately // in unit tests where we can properly mock the MCP server behavior. } + +// TestOptimizerIntegration_EmbeddingTimeTracking tests that embedding time is tracked and logged +func TestOptimizerIntegration_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_repo", + Description: "Get repository information", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Verify embedding time starts at 0 + embeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), embeddingTime, "Initial embedding time should be 0") + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // After ingestion, embedding time should be tracked + // Note: The actual time depends on Ollama performance, but it should be > 0 + finalEmbeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Greater(t, finalEmbeddingTime, time.Duration(0), + "Embedding time should be tracked after ingestion") +} + +// TestOptimizerIntegration_DisabledEmbeddingTime tests that embedding time is 0 when optimizer is disabled +func TestOptimizerIntegration_DisabledEmbeddingTime(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create optimizer integration with disabled optimizer + optimizerConfig := &Config{ + Enabled: false, + } + + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + mockClient := newMockIntegrationBackendClient() + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.Nil(t, integration, "Integration should be nil when optimizer is disabled") + + // Try to ingest backends - should return nil without error + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // This should handle nil integration gracefully + var nilIntegration *OptimizerIntegration + err = nilIntegration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} + +// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim.find_tool +func TestOptimizerIntegration_TokenMetrics(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client with multiple tools + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_pull_request", + Description: "Get a pull request from GitHub", + }, + { + Name: "list_repositories", + Description: "List repositories from GitHub", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingConfig: &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + }, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Get the find_tool handler + handler := integration.CreateFindToolHandler() + require.NotNil(t, handler) + + // Call optim.find_tool + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim.find_tool", + Arguments: map[string]any{ + "tool_description": "create issue", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify result contains token_metrics + require.NotNil(t, result.Content) + require.Len(t, result.Content, 1) + textResult, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "Result should be TextContent") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textResult.Text), &response) + require.NoError(t, err) + + // Verify token_metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]any) + require.True(t, ok, "Response should contain token_metrics") + + // Verify token metrics fields + baselineTokens, ok := tokenMetrics["baseline_tokens"].(float64) + require.True(t, ok, "token_metrics should contain baseline_tokens") + require.Greater(t, baselineTokens, float64(0), "baseline_tokens should be greater than 0") + + returnedTokens, ok := tokenMetrics["returned_tokens"].(float64) + require.True(t, ok, "token_metrics should contain returned_tokens") + require.GreaterOrEqual(t, returnedTokens, float64(0), "returned_tokens should be >= 0") + + tokensSaved, ok := tokenMetrics["tokens_saved"].(float64) + require.True(t, ok, "token_metrics should contain tokens_saved") + require.GreaterOrEqual(t, tokensSaved, float64(0), "tokens_saved should be >= 0") + + savingsPercentage, ok := tokenMetrics["savings_percentage"].(float64) + require.True(t, ok, "token_metrics should contain savings_percentage") + require.GreaterOrEqual(t, savingsPercentage, float64(0), "savings_percentage should be >= 0") + require.LessOrEqual(t, savingsPercentage, float64(100), "savings_percentage should be <= 100") + + // Verify tools are returned + tools, ok := response["tools"].([]any) + require.True(t, ok, "Response should contain tools") + require.Greater(t, len(tools), 0, "Should return at least one tool") +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 7ec8b2bab5..8092268c63 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -393,6 +393,7 @@ func New( // Initialize optimizer integration if enabled var optimizerInteg OptimizerIntegration + if cfg.OptimizerConfig != nil { if cfg.OptimizerConfig.Enabled { logger.Infow("Initializing optimizer integration (chromem-go)", @@ -423,16 +424,22 @@ func New( } logger.Info("Optimizer integration initialized successfully") - // Ingest discovered backends at startup (populate optimizer database) + // Ingest discovered backends into optimizer database (for semantic search) + // Note: Backends are already discovered and registered with vMCP regardless of optimizer + // This step indexes them in the optimizer database for semantic search + // Timing is handled inside IngestInitialBackends initialBackends := backendRegistry.List(ctx) if err := optimizerInteg.IngestInitialBackends(ctx, initialBackends); err != nil { - logger.Warnf("Failed to ingest initial backends: %v", err) + logger.Warnf("Failed to ingest initial backends into optimizer: %v", err) // Don't fail server startup - optimizer can still work with incremental ingestion } - } else { - logger.Info("Optimizer configuration present but disabled (enabled=false), skipping initialization") + // Note: IngestInitialBackends logs "Initial backend ingestion completed" with timing } + // When optimizer is disabled, backends are still discovered and registered with vMCP, + // but no optimizer ingestion occurs, so no log entry is needed } + // When optimizer is not configured, backends are still discovered and registered with vMCP, + // but no optimizer ingestion occurs, so no log entry is needed // Create Server instance srv := &Server{ From 9aaac55c543307525317b62813f6ee352f132778 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 10:42:09 +0000 Subject: [PATCH 06/16] fix: Bump operator-crds chart version to 0.0.97 after rebase --- deploy/charts/operator-crds/Chart.yaml | 2 +- deploy/charts/operator-crds/README.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deploy/charts/operator-crds/Chart.yaml b/deploy/charts/operator-crds/Chart.yaml index 47d15cab11..5da4ebde7a 100644 --- a/deploy/charts/operator-crds/Chart.yaml +++ b/deploy/charts/operator-crds/Chart.yaml @@ -2,5 +2,5 @@ apiVersion: v2 name: toolhive-operator-crds description: A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. type: application -version: 0.0.96 +version: 0.0.97 appVersion: "0.0.1" diff --git a/deploy/charts/operator-crds/README.md b/deploy/charts/operator-crds/README.md index 8e44047b92..fc1e4ab5f4 100644 --- a/deploy/charts/operator-crds/README.md +++ b/deploy/charts/operator-crds/README.md @@ -1,6 +1,6 @@ # ToolHive Operator CRDs Helm Chart -![Version: 0.0.96](https://img.shields.io/badge/Version-0.0.96-informational?style=flat-square) +![Version: 0.0.97](https://img.shields.io/badge/Version-0.0.97-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) A Helm chart for installing the ToolHive Operator CRDs into Kubernetes. From 8ad52eb5166f11ddce2695ac9962eb48a7551cbb Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 11:13:36 +0000 Subject: [PATCH 07/16] fix: Skip completed pods in checkPodsReady to prevent flaky e2e test failures The checkPodsReady function was checking all pods with matching labels, including old pods that had completed (Phase: Succeeded) from previous deployments. This caused the auth discovery e2e test to fail when old pods were still present during deployment updates. Fix: Skip pods that are not in Running phase and ensure at least one running pod exists after filtering. --- test/e2e/thv-operator/virtualmcp/helpers.go | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/test/e2e/thv-operator/virtualmcp/helpers.go b/test/e2e/thv-operator/virtualmcp/helpers.go index 449fbd177a..9959ee0b4c 100644 --- a/test/e2e/thv-operator/virtualmcp/helpers.go +++ b/test/e2e/thv-operator/virtualmcp/helpers.go @@ -86,8 +86,9 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe } for _, pod := range podList.Items { + // Skip pods that are not running (e.g., Succeeded, Failed from old deployments) if pod.Status.Phase != corev1.PodRunning { - return fmt.Errorf("pod %s is in phase %s", pod.Name, pod.Status.Phase) + continue } containerReady := false @@ -111,6 +112,17 @@ func checkPodsReady(ctx context.Context, c client.Client, namespace string, labe return fmt.Errorf("pod %s not ready", pod.Name) } } + + // After filtering, ensure we found at least one running pod + runningPods := 0 + for _, pod := range podList.Items { + if pod.Status.Phase == corev1.PodRunning { + runningPods++ + } + } + if runningPods == 0 { + return fmt.Errorf("no running pods found with labels %v", labels) + } return nil } From 628cb6fd9207336c8a1c581996f6c8c0121002da Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 11:28:40 +0000 Subject: [PATCH 08/16] fix: Add pod readiness checks before health endpoint verification The test was failing with 'connection reset by peer' errors when trying to connect to the health endpoint. This can happen if pods crash or restart between the BeforeAll setup and the actual test execution. Fix: Add explicit pod readiness verification right before the health check and also check pod readiness inside the Eventually loop to catch pods that crash during health check retries. This makes the test more robust by ensuring pods are stable before attempting HTTP connections. --- .../virtualmcp/virtualmcp_auth_discovery_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go index b498d3dce9..ccb1de81f4 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go @@ -1159,8 +1159,19 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd: } It("should list and call tools from all backends with discovered auth", func() { + By("Verifying vMCP pods are still running and ready before health check") + vmcpLabels := map[string]string{ + "app.kubernetes.io/name": "virtualmcpserver", + "app.kubernetes.io/instance": vmcpServerName, + } + WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second) + By("Verifying HTTP connectivity to VirtualMCPServer health endpoint") Eventually(func() error { + // Re-check pod readiness before each health check attempt + if err := checkPodsReady(ctx, k8sClient, testNamespace, vmcpLabels); err != nil { + return fmt.Errorf("pods not ready: %w", err) + } url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) resp, err := http.Get(url) if err != nil { From b7655905e5f8c691944fa2c204c8b10ab487a30d Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 12:27:36 +0000 Subject: [PATCH 09/16] fix: Add HTTP client timeout to health check in flaky e2e test The health check was using http.Get() without a timeout, which could cause hangs. Add an explicit HTTP client with 10s timeout and improve error messages to help diagnose connection reset issues. --- .../virtualmcp/virtualmcp_auth_discovery_test.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go index ccb1de81f4..b6285677b3 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_auth_discovery_test.go @@ -1166,6 +1166,11 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd: } WaitForPodsReady(ctx, k8sClient, testNamespace, vmcpLabels, 30*time.Second, 2*time.Second) + // Create HTTP client with timeout for health checks + healthCheckClient := &http.Client{ + Timeout: 10 * time.Second, + } + By("Verifying HTTP connectivity to VirtualMCPServer health endpoint") Eventually(func() error { // Re-check pod readiness before each health check attempt @@ -1173,9 +1178,9 @@ with socketserver.TCPServer(("", PORT), OIDCHandler) as httpd: return fmt.Errorf("pods not ready: %w", err) } url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) - resp, err := http.Get(url) + resp, err := healthCheckClient.Get(url) if err != nil { - return err + return fmt.Errorf("health check failed: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { From deb014078cfc0e384338f814d58dfccf0095a8f3 Mon Sep 17 00:00:00 2001 From: Yolanda Robla Mota Date: Tue, 20 Jan 2026 15:10:37 +0100 Subject: [PATCH 10/16] Add dynamic/static mode support to VirtualMCPServer operator (#3235) * remove docs * fixes from review * simplify code and fixes from review * fixes from review * fix ci --------- Co-authored-by: taskbot --- .../controllers/mcpremoteproxy_controller.go | 3 + .../virtualmcpserver_controller.go | 107 +- .../virtualmcpserver_deployment.go | 13 +- .../virtualmcpserver_vmcpconfig.go | 140 +- .../virtualmcpserver_vmcpconfig_test.go | 397 +++++- cmd/thv-operator/pkg/controllerutil/rbac.go | 4 + cmd/vmcp/app/commands.go | 31 +- ...olhive.stacklok.dev_virtualmcpservers.yaml | 46 + ...olhive.stacklok.dev_virtualmcpservers.yaml | 46 + docs/operator/crd-api.md | 24 +- pkg/vmcp/aggregator/discoverer.go | 76 + pkg/vmcp/aggregator/discoverer_test.go | 137 ++ pkg/vmcp/config/config.go | 52 + pkg/vmcp/config/zz_generated.deepcopy.go | 29 + pkg/vmcp/discovery/middleware_test.go | 58 +- pkg/vmcp/workloads/k8s.go | 3 + .../virtualmcp_discovered_mode_test.go | 379 +++++ .../virtualmcp/virtualmcp_lifecycle_test.go | 1245 ----------------- test/integration/vmcp/helpers/helpers_test.go | 76 + 19 files changed, 1557 insertions(+), 1309 deletions(-) delete mode 100644 test/e2e/thv-operator/virtualmcp/virtualmcp_lifecycle_test.go create mode 100644 test/integration/vmcp/helpers/helpers_test.go diff --git a/cmd/thv-operator/controllers/mcpremoteproxy_controller.go b/cmd/thv-operator/controllers/mcpremoteproxy_controller.go index e7d209d467..ef425303f0 100644 --- a/cmd/thv-operator/controllers/mcpremoteproxy_controller.go +++ b/cmd/thv-operator/controllers/mcpremoteproxy_controller.go @@ -465,6 +465,9 @@ func (r *MCPRemoteProxyReconciler) validateGroupRef(ctx context.Context, proxy * } // ensureRBACResources ensures that the RBAC resources are in place for the remote proxy +// TODO: This uses EnsureRBACResource which only creates RBAC but never updates them. +// Consider adopting the MCPRegistry pattern (pkg/registryapi/rbac.go) which uses +// CreateOrUpdate + RetryOnConflict to automatically update RBAC rules during operator upgrades. func (r *MCPRemoteProxyReconciler) ensureRBACResources(ctx context.Context, proxy *mcpv1alpha1.MCPRemoteProxy) error { proxyRunnerNameForRBAC := proxyRunnerServiceAccountNameForRemoteProxy(proxy.Name) diff --git a/cmd/thv-operator/controllers/virtualmcpserver_controller.go b/cmd/thv-operator/controllers/virtualmcpserver_controller.go index 9e1a194c2a..8adb1797f0 100644 --- a/cmd/thv-operator/controllers/virtualmcpserver_controller.go +++ b/cmd/thv-operator/controllers/virtualmcpserver_controller.go @@ -496,14 +496,31 @@ func (r *VirtualMCPServerReconciler) ensureAllResources( return nil } -// ensureRBACResources ensures that the RBAC resources are in place for the VirtualMCPServer +// ensureRBACResources ensures RBAC resources for VirtualMCPServer in dynamic mode. +// In static mode, RBAC creation is skipped. When switching dynamic→static, existing RBAC +// resources are NOT deleted - they persist until VirtualMCPServer deletion via owner references. +// This follows standard Kubernetes garbage collection patterns. +// +// TODO: This uses EnsureRBACResource which only creates RBAC but never updates them. +// Consider adopting the MCPRegistry pattern (pkg/registryapi/rbac.go) which uses +// CreateOrUpdate + RetryOnConflict to automatically update RBAC rules during operator upgrades. func (r *VirtualMCPServerReconciler) ensureRBACResources( ctx context.Context, vmcp *mcpv1alpha1.VirtualMCPServer, ) error { + // Determine the outgoing auth source mode + source := outgoingAuthSource(vmcp) + + // Static mode (inline): Skip RBAC creation/deletion + // Existing resources from dynamic mode persist until VirtualMCPServer deletion + if source == OutgoingAuthSourceInline { + return nil + } + + // Dynamic mode (discovered): Ensure RBAC resources exist serviceAccountName := vmcpServiceAccountName(vmcp.Name) - // Ensure Role with minimal permissions + // Ensure Role with permissions to discover backends and update status if err := ctrlutil.EnsureRBACResource(ctx, r.Client, r.Scheme, vmcp, "Role", func() client.Object { return &rbacv1.Role{ ObjectMeta: metav1.ObjectMeta{ @@ -848,13 +865,9 @@ func (r *VirtualMCPServerReconciler) containerNeedsUpdate( } // Check if service account has changed - expectedServiceAccountName := vmcpServiceAccountName(vmcp.Name) + expectedServiceAccountName := r.serviceAccountNameForVmcp(vmcp) currentServiceAccountName := deployment.Spec.Template.Spec.ServiceAccountName - if currentServiceAccountName != "" && currentServiceAccountName != expectedServiceAccountName { - return true - } - - return false + return currentServiceAccountName != expectedServiceAccountName } // deploymentMetadataNeedsUpdate checks if deployment-level metadata has changed @@ -1249,6 +1262,31 @@ func vmcpServiceAccountName(vmcpName string) string { return fmt.Sprintf("%s-vmcp", vmcpName) } +// outgoingAuthSource returns the outgoing auth source mode with default fallback. +// Returns OutgoingAuthSourceDiscovered if not specified. +func outgoingAuthSource(vmcp *mcpv1alpha1.VirtualMCPServer) string { + if vmcp.Spec.OutgoingAuth != nil && vmcp.Spec.OutgoingAuth.Source != "" { + return vmcp.Spec.OutgoingAuth.Source + } + return OutgoingAuthSourceDiscovered +} + +// serviceAccountNameForVmcp returns the service account name for a VirtualMCPServer +// based on its outgoing auth source mode. +// - Dynamic mode (discovered): Returns the dedicated service account name +// - Static mode (inline): Returns empty string (uses default service account) +func (*VirtualMCPServerReconciler) serviceAccountNameForVmcp(vmcp *mcpv1alpha1.VirtualMCPServer) string { + source := outgoingAuthSource(vmcp) + + // Static mode: Use default service account (no RBAC resources) + if source == OutgoingAuthSourceInline { + return "" + } + + // Dynamic mode: Use dedicated service account with K8s API permissions + return vmcpServiceAccountName(vmcp.Name) +} + // vmcpServiceName generates the service name for a VirtualMCPServer // Uses "vmcp-" prefix to distinguish from MCPServer's "mcp-{name}-proxy" pattern. // This allows VirtualMCPServer and MCPServer to coexist with the same base name. @@ -1472,10 +1510,7 @@ func (r *VirtualMCPServerReconciler) buildOutgoingAuthConfig( typedWorkloads []workloads.TypedWorkload, ) (*vmcpconfig.OutgoingAuthConfig, error) { // Determine source - default to "discovered" if not specified - source := OutgoingAuthSourceDiscovered - if vmcp.Spec.OutgoingAuth != nil && vmcp.Spec.OutgoingAuth.Source != "" { - source = vmcp.Spec.OutgoingAuth.Source - } + source := outgoingAuthSource(vmcp) outgoing := &vmcpconfig.OutgoingAuthConfig{ Source: source, @@ -1491,10 +1526,11 @@ func (r *VirtualMCPServerReconciler) buildOutgoingAuthConfig( outgoing.Default = defaultStrategy } - // Discover ExternalAuthConfig from MCPServers if source is "discovered" - if source == OutgoingAuthSourceDiscovered { - r.discoverExternalAuthConfigs(ctx, vmcp, typedWorkloads, outgoing) - } + // Discover ExternalAuthConfig from MCPServers to populate backend auth configs. + // This function is called from ensureVmcpConfigConfigMap only for inline/static mode, + // where we need full backend details in the ConfigMap. For discovered/dynamic mode, + // this function is not called, keeping the ConfigMap minimal. + r.discoverExternalAuthConfigs(ctx, vmcp, typedWorkloads, outgoing) // Apply inline overrides (works for all source modes) if vmcp.Spec.OutgoingAuth != nil && vmcp.Spec.OutgoingAuth.Backends != nil { @@ -1510,9 +1546,42 @@ func (r *VirtualMCPServerReconciler) buildOutgoingAuthConfig( return outgoing, nil } +// convertBackendsToStaticBackends converts Backend objects to StaticBackendConfig for ConfigMap embedding. +// Preserves metadata and uses transport types from workload Specs. +// Logs warnings when backends are skipped due to missing URL or transport information. +func convertBackendsToStaticBackends( + ctx context.Context, + backends []vmcptypes.Backend, + transportMap map[string]string, +) []vmcpconfig.StaticBackendConfig { + logger := log.FromContext(ctx) + static := make([]vmcpconfig.StaticBackendConfig, 0, len(backends)) + for _, backend := range backends { + if backend.BaseURL == "" { + logger.V(1).Info("Skipping backend without URL in static mode", + "backend", backend.Name) + continue + } + + transport := transportMap[backend.Name] + if transport == "" { + logger.V(1).Info("Skipping backend without transport information in static mode", + "backend", backend.Name) + continue + } + + static = append(static, vmcpconfig.StaticBackendConfig{ + Name: backend.Name, + URL: backend.BaseURL, + Transport: transport, + Metadata: backend.Metadata, + }) + } + return static +} + // discoverBackends discovers all MCPServers in the referenced MCPGroup and returns // a list of DiscoveredBackend objects with their current status. -// This reuses the existing workload discovery code from pkg/vmcp/workloads. // //nolint:gocyclo func (r *VirtualMCPServerReconciler) discoverBackends( @@ -1521,13 +1590,9 @@ func (r *VirtualMCPServerReconciler) discoverBackends( ) ([]mcpv1alpha1.DiscoveredBackend, error) { ctxLogger := log.FromContext(ctx) - // Create groups manager using the controller's client and VirtualMCPServer's namespace groupsManager := groups.NewCRDManager(r.Client, vmcp.Namespace) - - // Create K8S workload discoverer for the VirtualMCPServer's namespace workloadDiscoverer := workloads.NewK8SDiscovererWithClient(r.Client, vmcp.Namespace) - // Get all workloads in the group typedWorkloads, err := workloadDiscoverer.ListWorkloadsInGroup(ctx, vmcp.Spec.Config.Group) if err != nil { return nil, fmt.Errorf("failed to list workloads in group: %w", err) diff --git a/cmd/thv-operator/controllers/virtualmcpserver_deployment.go b/cmd/thv-operator/controllers/virtualmcpserver_deployment.go index fdbd26370b..ff523050c7 100644 --- a/cmd/thv-operator/controllers/virtualmcpserver_deployment.go +++ b/cmd/thv-operator/controllers/virtualmcpserver_deployment.go @@ -49,7 +49,8 @@ const ( vmcpReadinessFailures = int32(3) // consecutive failures before removing from service ) -// RBAC rules for VirtualMCPServer service account +// RBAC rules for VirtualMCPServer service account in dynamic mode +// These rules allow vMCP to discover backends and configurations at runtime var vmcpRBACRules = []rbacv1.PolicyRule{ { APIGroups: []string{""}, @@ -58,9 +59,14 @@ var vmcpRBACRules = []rbacv1.PolicyRule{ }, { APIGroups: []string{"toolhive.stacklok.dev"}, - Resources: []string{"mcpgroups", "mcpservers", "mcpremoteproxies", "mcpexternalauthconfigs"}, + Resources: []string{"mcpgroups", "mcpservers", "mcpremoteproxies", "mcpexternalauthconfigs", "mcptoolconfigs"}, Verbs: []string{"get", "list", "watch"}, }, + { + APIGroups: []string{"toolhive.stacklok.dev"}, + Resources: []string{"virtualmcpservers/status"}, + Verbs: []string{"update", "patch"}, + }, } // deploymentForVirtualMCPServer returns a VirtualMCPServer Deployment object @@ -80,6 +86,7 @@ func (r *VirtualMCPServerReconciler) deploymentForVirtualMCPServer( deploymentLabels, deploymentAnnotations := r.buildDeploymentMetadataForVmcp(ls, vmcp) deploymentTemplateLabels, deploymentTemplateAnnotations := r.buildPodTemplateMetadata(ls, vmcp, vmcpConfigChecksum) podSecurityContext, containerSecurityContext := r.buildSecurityContextsForVmcp(ctx, vmcp) + serviceAccountName := r.serviceAccountNameForVmcp(vmcp) dep := &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ @@ -99,7 +106,7 @@ func (r *VirtualMCPServerReconciler) deploymentForVirtualMCPServer( Annotations: deploymentTemplateAnnotations, }, Spec: corev1.PodSpec{ - ServiceAccountName: vmcpServiceAccountName(vmcp.Name), + ServiceAccountName: serviceAccountName, Containers: []corev1.Container{{ Image: getVmcpImage(), ImagePullPolicy: corev1.PullIfNotPresent, diff --git a/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig.go b/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig.go index 485e6db1a9..9be45407a9 100644 --- a/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig.go +++ b/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig.go @@ -13,7 +13,11 @@ import ( "github.com/stacklok/toolhive/cmd/thv-operator/pkg/kubernetes/configmaps" "github.com/stacklok/toolhive/cmd/thv-operator/pkg/oidc" "github.com/stacklok/toolhive/cmd/thv-operator/pkg/runconfig/configmap/checksum" - "github.com/stacklok/toolhive/cmd/thv-operator/pkg/vmcpconfig" + operatorvmcpconfig "github.com/stacklok/toolhive/cmd/thv-operator/pkg/vmcpconfig" + "github.com/stacklok/toolhive/pkg/groups" + vmcptypes "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" "github.com/stacklok/toolhive/pkg/vmcp/workloads" ) @@ -25,14 +29,9 @@ func (r *VirtualMCPServerReconciler) ensureVmcpConfigConfigMap( vmcp *mcpv1alpha1.VirtualMCPServer, typedWorkloads []workloads.TypedWorkload, ) error { - ctxLogger := log.FromContext(ctx) - - // Create OIDC resolver to handle all OIDC types (kubernetes, configMap, inline) + // Create OIDC resolver and converter for CRD-to-config transformation oidcResolver := oidc.NewResolver(r.Client) - - // Convert CRD to vmcp config using converter with OIDC resolver and Kubernetes client - // The client is needed to fetch referenced VirtualMCPCompositeToolDefinition resources - converter, err := vmcpconfig.NewConverter(oidcResolver, r.Client) + converter, err := operatorvmcpconfig.NewConverter(oidcResolver, r.Client) if err != nil { return fmt.Errorf("failed to create vmcp converter: %w", err) } @@ -41,21 +40,45 @@ func (r *VirtualMCPServerReconciler) ensureVmcpConfigConfigMap( return fmt.Errorf("failed to create vmcp Config from VirtualMCPServer: %w", err) } - // For dynamic mode (source: "discovered"), preserve the Source field so the vMCP pod - // can start BackendWatcher for runtime backend discovery. - // For inline mode, discover backends at reconcile time and include in ConfigMap. - if config.OutgoingAuth != nil && config.OutgoingAuth.Source != "discovered" { + // Static mode (inline): Embed full backend details in ConfigMap. + // Dynamic mode (discovered): vMCP discovers backends at runtime via K8s API. + if config.OutgoingAuth != nil && config.OutgoingAuth.Source == "inline" { + // Build auth config with backend details discoveredAuthConfig, err := r.buildOutgoingAuthConfig(ctx, vmcp, typedWorkloads) if err != nil { - ctxLogger.V(1).Info("Failed to build discovered auth config, using spec-only config", - "error", err) - } else if discoveredAuthConfig != nil { + return fmt.Errorf("failed to build auth config for static mode: %w", err) + } + if discoveredAuthConfig != nil { config.OutgoingAuth = discoveredAuthConfig } + + // Discover backends with metadata + backends, err := r.discoverBackendsWithMetadata(ctx, vmcp) + if err != nil { + return fmt.Errorf("failed to discover backends for static mode: %w", err) + } + + // Get transport types from workload specs + transportMap, err := r.buildTransportMap(ctx, vmcp.Namespace, typedWorkloads) + if err != nil { + return fmt.Errorf("failed to build transport map for static mode: %w", err) + } + + config.Backends = convertBackendsToStaticBackends(ctx, backends, transportMap) + + // Validate at least one backend exists + if len(config.Backends) == 0 { + return fmt.Errorf( + "static mode requires at least one backend with valid transport (%v), "+ + "but none were discovered in group %s", + vmcpconfig.StaticModeAllowedTransports, + config.Group, + ) + } } // Validate the vmcp Config before creating the ConfigMap - validator := vmcpconfig.NewValidator() + validator := operatorvmcpconfig.NewValidator() if err := validator.Validate(ctx, config); err != nil { return fmt.Errorf("invalid vmcp Config: %w", err) } @@ -104,3 +127,88 @@ func labelsForVmcpConfig(vmcpName string) map[string]string { "toolhive.stacklok.io/managed-by": "toolhive-operator", } } + +// discoverBackendsWithMetadata discovers backends and returns full Backend objects with metadata. +// Used in static mode for ConfigMap generation to preserve backend metadata. +func (r *VirtualMCPServerReconciler) discoverBackendsWithMetadata( + ctx context.Context, + vmcp *mcpv1alpha1.VirtualMCPServer, +) ([]vmcptypes.Backend, error) { + groupsManager := groups.NewCRDManager(r.Client, vmcp.Namespace) + workloadDiscoverer := workloads.NewK8SDiscovererWithClient(r.Client, vmcp.Namespace) + + // Build auth config if OutgoingAuth is configured + var authConfig *vmcpconfig.OutgoingAuthConfig + if vmcp.Spec.OutgoingAuth != nil { + typedWorkloads, err := workloadDiscoverer.ListWorkloadsInGroup(ctx, vmcp.Spec.Config.Group) + if err != nil { + return nil, fmt.Errorf("failed to list workloads in group: %w", err) + } + + authConfig, err = r.buildOutgoingAuthConfig(ctx, vmcp, typedWorkloads) + if err != nil { + ctxLogger := log.FromContext(ctx) + ctxLogger.V(1).Info("Failed to build outgoing auth config, continuing without authentication", + "error", err, + "virtualmcpserver", vmcp.Name, + "namespace", vmcp.Namespace) + authConfig = nil // Continue without auth config on error + } + } + + backendDiscoverer := aggregator.NewUnifiedBackendDiscoverer(workloadDiscoverer, groupsManager, authConfig) + backends, err := backendDiscoverer.Discover(ctx, vmcp.Spec.Config.Group) + if err != nil { + return nil, fmt.Errorf("failed to discover backends: %w", err) + } + + return backends, nil +} + +// buildTransportMap builds a map of backend names to transport types from workload Specs. +// Used in static mode to populate transport field in ConfigMap. +func (r *VirtualMCPServerReconciler) buildTransportMap( + ctx context.Context, + namespace string, + typedWorkloads []workloads.TypedWorkload, +) (map[string]string, error) { + transportMap := make(map[string]string, len(typedWorkloads)) + + mcpServerMap, err := r.listMCPServersAsMap(ctx, namespace) + if err != nil { + return nil, fmt.Errorf("failed to list MCPServers: %w", err) + } + + mcpRemoteProxyMap, err := r.listMCPRemoteProxiesAsMap(ctx, namespace) + if err != nil { + return nil, fmt.Errorf("failed to list MCPRemoteProxies: %w", err) + } + + for _, workload := range typedWorkloads { + var transport string + + switch workload.Type { + case workloads.WorkloadTypeMCPServer: + if mcpServer, found := mcpServerMap[workload.Name]; found { + // Read effective transport (ProxyMode takes precedence over Transport) + // For stdio servers, ProxyMode indicates how they're proxied (sse or streamable-http) + if mcpServer.Spec.ProxyMode != "" { + transport = string(mcpServer.Spec.ProxyMode) + } else { + transport = string(mcpServer.Spec.Transport) + } + } + + case workloads.WorkloadTypeMCPRemoteProxy: + if mcpRemoteProxy, found := mcpRemoteProxyMap[workload.Name]; found { + transport = string(mcpRemoteProxy.Spec.Transport) + } + } + + if transport != "" { + transportMap[workload.Name] = transport + } + } + + return transportMap, nil +} diff --git a/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig_test.go b/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig_test.go index a0cf000658..8a0b378806 100644 --- a/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig_test.go +++ b/cmd/thv-operator/controllers/virtualmcpserver_vmcpconfig_test.go @@ -665,7 +665,7 @@ func TestYAMLMarshalingDeterminism(t *testing.T) { assert.NotEmpty(t, firstResult) assert.Greater(t, len(firstResult), 100, "YAML output should contain substantial content") - t.Logf("✅ All %d marshaling iterations produced identical output (%d bytes)", + t.Logf("All %d marshaling iterations produced identical output (%d bytes)", iterations, len(results[0])) } @@ -969,3 +969,398 @@ func TestVirtualMCPServerReconciler_CompositeToolRefs_NotFound(t *testing.T) { require.Error(t, err, "should fail when referenced tool doesn't exist") assert.Contains(t, err.Error(), "not found", "error should mention not found") } + +// TestConfigMapContent_DynamicMode tests that in dynamic mode (discovered), +// the ConfigMap contains minimal content without backends +func TestConfigMapContent_DynamicMode(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testScheme := createRunConfigTestScheme() + + // Create MCPGroup for workload discovery + mcpGroup := &mcpv1alpha1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-group", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPGroupSpec{}, + Status: mcpv1alpha1.MCPGroupStatus{ + Phase: mcpv1alpha1.MCPGroupPhaseReady, + }, + } + + // Create VirtualMCPServer in dynamic mode (source: discovered) + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vmcp", + Namespace: "default", + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + Config: vmcpconfig.Config{Group: "test-group"}, + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "discovered", // Dynamic mode + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(testScheme). + WithObjects(vmcpServer, mcpGroup). + Build() + + reconciler := &VirtualMCPServerReconciler{ + Client: fakeClient, + Scheme: testScheme, + } + + // Discover workloads + workloadDiscoverer := workloads.NewK8SDiscovererWithClient(fakeClient, vmcpServer.Namespace) + workloadNames, err := workloadDiscoverer.ListWorkloadsInGroup(ctx, vmcpServer.Spec.Config.Group) + require.NoError(t, err) + + // Create ConfigMap + err = reconciler.ensureVmcpConfigConfigMap(ctx, vmcpServer, workloadNames) + require.NoError(t, err) + + // Verify ConfigMap was created + configMap := &corev1.ConfigMap{} + err = fakeClient.Get(ctx, types.NamespacedName{ + Name: vmcpConfigMapName("test-vmcp"), + Namespace: "default", + }, configMap) + require.NoError(t, err) + + // Parse the YAML config + var config vmcpconfig.Config + err = yaml.Unmarshal([]byte(configMap.Data["config.yaml"]), &config) + require.NoError(t, err) + + // In dynamic mode, ConfigMap should have minimal content: + // - OutgoingAuth with source: discovered + // - No auth backends in OutgoingAuth (vMCP discovers at runtime) + // - No static backends in Backends (vMCP discovers at runtime) + require.NotNil(t, config.OutgoingAuth) + assert.Equal(t, "discovered", config.OutgoingAuth.Source, "source should be discovered") + assert.Empty(t, config.OutgoingAuth.Backends, "auth backends should be empty in dynamic mode") + assert.Empty(t, config.Backends, "static backends should be empty in dynamic mode") + + t.Log("Dynamic mode ConfigMap contains minimal content without backends") +} + +// TestConfigMapContent_StaticMode_InlineOverrides tests that in static mode (inline), +// explicitly specified backends in the spec are preserved in the ConfigMap. +// This tests inline overrides, not discovery. See TestConfigMapContent_StaticModeWithDiscovery +// for testing actual backend discovery from MCPServers in the group. +func TestConfigMapContent_StaticMode_InlineOverrides(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testScheme := createRunConfigTestScheme() + + // Create MCPGroup for workload discovery + mcpGroup := &mcpv1alpha1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-group", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPGroupSpec{}, + Status: mcpv1alpha1.MCPGroupStatus{ + Phase: mcpv1alpha1.MCPGroupPhaseReady, + }, + } + + // Create MCPServer in the group so static mode has something to discover + // This is needed because static mode validates that at least one backend exists + mcpServer := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-backend", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: "test-group", + Transport: "sse", // Required for backend discovery + }, + Status: mcpv1alpha1.MCPServerStatus{ + Phase: mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://test-backend.default.svc.cluster.local:8080", + }, + } + + // Create VirtualMCPServer in static mode (source: inline) + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vmcp", + Namespace: "default", + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + Config: vmcpconfig.Config{Group: "test-group"}, + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "inline", // Static mode + Backends: map[string]mcpv1alpha1.BackendAuthConfig{ + "test-backend": { + Type: mcpv1alpha1.BackendAuthTypeDiscovered, + }, + }, + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(testScheme). + WithObjects(vmcpServer, mcpGroup, mcpServer). + WithStatusSubresource(mcpServer). + Build() + + reconciler := &VirtualMCPServerReconciler{ + Client: fakeClient, + Scheme: testScheme, + } + + // Discover workloads + workloadDiscoverer := workloads.NewK8SDiscovererWithClient(fakeClient, vmcpServer.Namespace) + workloadNames, err := workloadDiscoverer.ListWorkloadsInGroup(ctx, vmcpServer.Spec.Config.Group) + require.NoError(t, err) + + // Create ConfigMap + err = reconciler.ensureVmcpConfigConfigMap(ctx, vmcpServer, workloadNames) + require.NoError(t, err) + + // Verify ConfigMap was created + configMap := &corev1.ConfigMap{} + err = fakeClient.Get(ctx, types.NamespacedName{ + Name: vmcpConfigMapName("test-vmcp"), + Namespace: "default", + }, configMap) + require.NoError(t, err) + + // Parse the YAML config + var config vmcpconfig.Config + err = yaml.Unmarshal([]byte(configMap.Data["config.yaml"]), &config) + require.NoError(t, err) + + // In static mode with inline backends, ConfigMap should preserve them: + // - OutgoingAuth with source: inline + // - Backends from spec.outgoingAuth.backends are included + require.NotNil(t, config.OutgoingAuth) + assert.Equal(t, "inline", config.OutgoingAuth.Source, "source should be inline") + require.NotEmpty(t, config.OutgoingAuth.Backends, "backends should be present in static mode") + + // Verify the inline backend from spec is present + _, exists := config.OutgoingAuth.Backends["test-backend"] + assert.True(t, exists, "inline backend from spec should be present in ConfigMap") + + t.Log("Static mode ConfigMap preserves inline backend overrides from spec") +} + +// TestConfigMapContent_StaticModeWithDiscovery tests that in static mode (inline), +// the ConfigMap contains discovered backend auth configs from MCPServer ExternalAuthConfigRefs +func TestConfigMapContent_StaticModeWithDiscovery(t *testing.T) { + t.Parallel() + + ctx := context.Background() + testScheme := createRunConfigTestScheme() + + // Create MCPGroup for workload discovery + mcpGroup := &mcpv1alpha1.MCPGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-group", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPGroupSpec{}, + Status: mcpv1alpha1.MCPGroupStatus{ + Phase: mcpv1alpha1.MCPGroupPhaseReady, + }, + } + + // Create MCPExternalAuthConfig that will be referenced by MCPServer + externalAuthConfig := &mcpv1alpha1.MCPExternalAuthConfig{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-auth-config", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPExternalAuthConfigSpec{ + Type: mcpv1alpha1.ExternalAuthTypeUnauthenticated, + }, + } + + // Create MCPServer with ExternalAuthConfigRef and Status + mcpServer := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "discovered-backend", + Namespace: "default", + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: "test-group", + Transport: "sse", // Required for static mode backend discovery + ExternalAuthConfigRef: &mcpv1alpha1.ExternalAuthConfigRef{ + Name: "test-auth-config", + }, + }, + Status: mcpv1alpha1.MCPServerStatus{ + Phase: mcpv1alpha1.MCPServerPhaseRunning, + URL: "http://discovered-backend.default.svc.cluster.local:8080", + }, + } + + // Create VirtualMCPServer in static mode (source: inline) WITHOUT inline backends + vmcpServer := &mcpv1alpha1.VirtualMCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-vmcp", + Namespace: "default", + }, + Spec: mcpv1alpha1.VirtualMCPServerSpec{ + Config: vmcpconfig.Config{Group: "test-group"}, + IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ + Type: "anonymous", + }, + OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ + Source: "inline", // Static mode - should discover backends + }, + }, + } + + fakeClient := fake.NewClientBuilder(). + WithScheme(testScheme). + WithObjects(vmcpServer, mcpGroup, mcpServer, externalAuthConfig). + WithStatusSubresource(mcpServer). + Build() + + reconciler := &VirtualMCPServerReconciler{ + Client: fakeClient, + Scheme: testScheme, + } + + // Discover workloads + workloadDiscoverer := workloads.NewK8SDiscovererWithClient(fakeClient, vmcpServer.Namespace) + workloadNames, err := workloadDiscoverer.ListWorkloadsInGroup(ctx, vmcpServer.Spec.Config.Group) + require.NoError(t, err) + require.NotEmpty(t, workloadNames, "should have discovered the MCPServer") + + // Create ConfigMap + err = reconciler.ensureVmcpConfigConfigMap(ctx, vmcpServer, workloadNames) + require.NoError(t, err) + + // Verify ConfigMap was created + configMap := &corev1.ConfigMap{} + err = fakeClient.Get(ctx, types.NamespacedName{ + Name: vmcpConfigMapName("test-vmcp"), + Namespace: "default", + }, configMap) + require.NoError(t, err) + + // Parse the YAML config + var config vmcpconfig.Config + err = yaml.Unmarshal([]byte(configMap.Data["config.yaml"]), &config) + require.NoError(t, err) + + // In static mode with discovery, ConfigMap should have: + // - OutgoingAuth with source: inline and auth configs + // - Backends populated with URLs and transport types for zero-K8s-access mode + require.NotNil(t, config.OutgoingAuth) + assert.Equal(t, "inline", config.OutgoingAuth.Source, "source should be inline") + require.NotEmpty(t, config.OutgoingAuth.Backends, "backends should be discovered in static mode") + + // Verify the discovered backend auth config is present + discoveredBackend, exists := config.OutgoingAuth.Backends["discovered-backend"] + require.True(t, exists, "discovered backend should be present in ConfigMap") + require.NotNil(t, discoveredBackend, "discovered backend should have auth strategy") + assert.Equal(t, "unauthenticated", discoveredBackend.Type, "backend should have correct auth type") + + // Verify static backend configurations (URLs + transport) are populated + require.NotEmpty(t, config.Backends, "static backends with URLs should be populated in static mode") + + // Find the discovered backend in the static backend list + var foundBackend *vmcpconfig.StaticBackendConfig + for i := range config.Backends { + if config.Backends[i].Name == "discovered-backend" { + foundBackend = &config.Backends[i] + break + } + } + require.NotNil(t, foundBackend, "discovered backend should be in static backends list") + assert.NotEmpty(t, foundBackend.URL, "backend should have URL populated") + assert.NotEmpty(t, foundBackend.Transport, "backend should have transport type populated") + + // Verify metadata is preserved (group, tool_type, workload_type, namespace) + require.NotNil(t, foundBackend.Metadata, "backend should have metadata") + assert.Equal(t, "test-group", foundBackend.Metadata["group"], "backend should have group metadata") + assert.Equal(t, "mcp", foundBackend.Metadata["tool_type"], "backend should have tool_type metadata") + assert.Equal(t, "mcp_server", foundBackend.Metadata["workload_type"], "backend should have workload_type metadata") + assert.Equal(t, "default", foundBackend.Metadata["namespace"], "backend should have namespace metadata") + + t.Log("Static mode ConfigMap contains both auth configs, backend URLs/transports, and metadata") +} + +// TestConvertBackendsToStaticBackends_SkipsInvalidBackends tests that backends +// without URL or transport are skipped with appropriate logging +func TestConvertBackendsToStaticBackends_SkipsInvalidBackends(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + backends := []vmcp.Backend{ + { + Name: "valid-backend", + BaseURL: "http://backend1:8080", + TransportType: "sse", + Metadata: map[string]string{"key": "value"}, + }, + { + Name: "no-url-backend", + BaseURL: "", // Missing URL + TransportType: "sse", + }, + { + Name: "no-transport-backend", + BaseURL: "http://backend2:8080", + // Transport will be missing from map + }, + } + + transportMap := map[string]string{ + "valid-backend": "sse", + "no-url-backend": "streamable-http", + // "no-transport-backend" intentionally missing + } + + result := convertBackendsToStaticBackends(ctx, backends, transportMap) + + // Should only include the valid backend + assert.Len(t, result, 1, "should only include backends with URL and transport") + assert.Equal(t, "valid-backend", result[0].Name) + assert.Equal(t, "http://backend1:8080", result[0].URL) + assert.Equal(t, "sse", result[0].Transport) + assert.Equal(t, "value", result[0].Metadata["key"]) +} + +// TestStaticModeTransportConstants verifies that the transport constants match the CRD enum. +// This test ensures consistency between runtime validation and CRD schema validation. +func TestStaticModeTransportConstants(t *testing.T) { + t.Parallel() + + // Define the expected transports that should be in the CRD enum. + // If this test fails, it means the CRD enum in StaticBackendConfig.Transport + // is out of sync with vmcpconfig.StaticModeAllowedTransports. + expectedTransports := []string{vmcpconfig.TransportSSE, vmcpconfig.TransportStreamableHTTP} + + // Verify the slice matches exactly + assert.ElementsMatch(t, expectedTransports, vmcpconfig.StaticModeAllowedTransports, + "StaticModeAllowedTransports must match the transport constants") + + // Verify individual constants have expected values + assert.Equal(t, "sse", vmcpconfig.TransportSSE, "TransportSSE constant value") + assert.Equal(t, "streamable-http", vmcpconfig.TransportStreamableHTTP, "TransportStreamableHTTP constant value") + + // NOTE: When updating allowed transports: + // 1. Update the constants in pkg/vmcp/config/config.go + // 2. Update the CRD enum in StaticBackendConfig.Transport: +kubebuilder:validation:Enum=... + // 3. Run: task operator-generate && task operator-manifests + // 4. This test will verify the constants match the expected values +} diff --git a/cmd/thv-operator/pkg/controllerutil/rbac.go b/cmd/thv-operator/pkg/controllerutil/rbac.go index a7eadd2938..76c0865dc6 100644 --- a/cmd/thv-operator/pkg/controllerutil/rbac.go +++ b/cmd/thv-operator/pkg/controllerutil/rbac.go @@ -13,6 +13,10 @@ import ( ) // EnsureRBACResource is a generic helper function to ensure a Kubernetes RBAC resource exists +// LIMITATION: This only creates resources if they don't exist - it does NOT update them. +// If RBAC rules change in an operator upgrade, existing resources won't be updated. +// For a better pattern that supports updates, see pkg/registryapi/rbac.go which uses +// CreateOrUpdate + RetryOnConflict. func EnsureRBACResource( ctx context.Context, c client.Client, diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 2d5c8f28e3..5bbb21e635 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -228,17 +228,28 @@ func discoverBackends(ctx context.Context, cfg *config.Config) ([]vmcp.Backend, return nil, nil, fmt.Errorf("failed to create backend client: %w", err) } - // Initialize managers for backend discovery - logger.Info("Initializing group manager") - groupsManager, err := groups.NewManager() - if err != nil { - return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) - } + // Create backend discoverer based on configuration mode + var discoverer aggregator.BackendDiscoverer + if len(cfg.Backends) > 0 { + // Static mode: Use pre-configured backends from config (no K8s API access needed) + logger.Infof("Static mode: using %d pre-configured backends", len(cfg.Backends)) + discoverer = aggregator.NewUnifiedBackendDiscovererWithStaticBackends( + cfg.Backends, + cfg.OutgoingAuth, + cfg.Group, + ) + } else { + // Dynamic mode: Discover backends at runtime from K8s API + logger.Info("Dynamic mode: initializing group manager for backend discovery") + groupsManager, err := groups.NewManager() + if err != nil { + return nil, nil, fmt.Errorf("failed to create groups manager: %w", err) + } - // Create backend discoverer based on runtime environment - discoverer, err := aggregator.NewBackendDiscoverer(ctx, groupsManager, cfg.OutgoingAuth) - if err != nil { - return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) + discoverer, err = aggregator.NewBackendDiscoverer(ctx, groupsManager, cfg.OutgoingAuth) + if err != nil { + return nil, nil, fmt.Errorf("failed to create backend discoverer: %w", err) + } } logger.Infof("Discovering backends in group: %s", cfg.Group) diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index a20b4b1625..9e3f958641 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -215,6 +215,51 @@ spec: data included in audit logs (in bytes). type: integer type: object + backends: + description: |- + Backends defines pre-configured backend servers for static mode. + When OutgoingAuth.Source is "inline", this field contains the full list of backend + servers with their URLs and transport types, eliminating the need for K8s API access. + When OutgoingAuth.Source is "discovered", this field is empty and backends are + discovered at runtime via Kubernetes API. + items: + description: |- + StaticBackendConfig defines a pre-configured backend server for static mode. + This allows vMCP to operate without Kubernetes API access by embedding all backend + information directly in the configuration. + properties: + metadata: + additionalProperties: + type: string + description: |- + Metadata is a custom key-value map for storing additional backend information + such as labels, tags, or other arbitrary data (e.g., "env": "prod", "region": "us-east-1"). + This is NOT Kubernetes ObjectMeta - it's a simple string map for user-defined metadata. + Reserved keys: "group" is automatically set by vMCP and any user-provided value will be overridden. + type: object + name: + description: |- + Name is the backend identifier. + Must match the backend name from the MCPGroup for auth config resolution. + type: string + transport: + description: |- + Transport is the MCP transport protocol: "sse" or "streamable-http" + Only network transports supported by vMCP client are allowed. + enum: + - sse + - streamable-http + type: string + url: + description: URL is the backend's MCP server base URL. + pattern: ^https?:// + type: string + required: + - name + - transport + - url + type: object + type: array compositeToolRefs: description: |- CompositeToolRefs references VirtualMCPCompositeToolDefinition resources @@ -517,6 +562,7 @@ spec: type: boolean issuer: description: Issuer is the OIDC issuer URL. + pattern: ^https?:// type: string protectedResourceAllowPrivateIp: description: |- diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 466a0906ce..d9631d935c 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -218,6 +218,51 @@ spec: data included in audit logs (in bytes). type: integer type: object + backends: + description: |- + Backends defines pre-configured backend servers for static mode. + When OutgoingAuth.Source is "inline", this field contains the full list of backend + servers with their URLs and transport types, eliminating the need for K8s API access. + When OutgoingAuth.Source is "discovered", this field is empty and backends are + discovered at runtime via Kubernetes API. + items: + description: |- + StaticBackendConfig defines a pre-configured backend server for static mode. + This allows vMCP to operate without Kubernetes API access by embedding all backend + information directly in the configuration. + properties: + metadata: + additionalProperties: + type: string + description: |- + Metadata is a custom key-value map for storing additional backend information + such as labels, tags, or other arbitrary data (e.g., "env": "prod", "region": "us-east-1"). + This is NOT Kubernetes ObjectMeta - it's a simple string map for user-defined metadata. + Reserved keys: "group" is automatically set by vMCP and any user-provided value will be overridden. + type: object + name: + description: |- + Name is the backend identifier. + Must match the backend name from the MCPGroup for auth config resolution. + type: string + transport: + description: |- + Transport is the MCP transport protocol: "sse" or "streamable-http" + Only network transports supported by vMCP client are allowed. + enum: + - sse + - streamable-http + type: string + url: + description: URL is the backend's MCP server base URL. + pattern: ^https?:// + type: string + required: + - name + - transport + - url + type: object + type: array compositeToolRefs: description: |- CompositeToolRefs references VirtualMCPCompositeToolDefinition resources @@ -520,6 +565,7 @@ spec: type: boolean issuer: description: Issuer is the OIDC issuer URL. + pattern: ^https?:// type: string protectedResourceAllowPrivateIp: description: |- diff --git a/docs/operator/crd-api.md b/docs/operator/crd-api.md index 4250170268..22a76b681d 100644 --- a/docs/operator/crd-api.md +++ b/docs/operator/crd-api.md @@ -235,6 +235,7 @@ _Appears in:_ | --- | --- | --- | --- | | `name` _string_ | Name is the virtual MCP server name. | | | | `groupRef` _string_ | Group references an existing MCPGroup that defines backend workloads.
In Kubernetes, the referenced MCPGroup must exist in the same namespace. | | Required: \{\}
| +| `backends` _[vmcp.config.StaticBackendConfig](#vmcpconfigstaticbackendconfig) array_ | Backends defines pre-configured backend servers for static mode.
When OutgoingAuth.Source is "inline", this field contains the full list of backend
servers with their URLs and transport types, eliminating the need for K8s API access.
When OutgoingAuth.Source is "discovered", this field is empty and backends are
discovered at runtime via Kubernetes API. | | | | `incomingAuth` _[vmcp.config.IncomingAuthConfig](#vmcpconfigincomingauthconfig)_ | IncomingAuth configures how clients authenticate to the virtual MCP server.
When using the Kubernetes operator, this is populated by the converter from
VirtualMCPServerSpec.IncomingAuth and any values set here will be superseded. | | | | `outgoingAuth` _[vmcp.config.OutgoingAuthConfig](#vmcpconfigoutgoingauthconfig)_ | OutgoingAuth configures how the virtual MCP server authenticates to backends.
When using the Kubernetes operator, this is populated by the converter from
VirtualMCPServerSpec.OutgoingAuth and any values set here will be superseded. | | | | `aggregation` _[vmcp.config.AggregationConfig](#vmcpconfigaggregationconfig)_ | Aggregation defines tool aggregation and conflict resolution strategies.
Supports ToolConfigRef for Kubernetes-native MCPToolConfig resource references. | | | @@ -343,7 +344,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `issuer` _string_ | Issuer is the OIDC issuer URL. | | | +| `issuer` _string_ | Issuer is the OIDC issuer URL. | | Pattern: `^https?://`
| | `clientId` _string_ | ClientID is the OAuth client ID. | | | | `clientSecretEnv` _string_ | ClientSecretEnv is the name of the environment variable containing the client secret.
This is the secure way to reference secrets - the actual secret value is never stored
in configuration files, only the environment variable name.
The secret value will be resolved from this environment variable at runtime. | | | | `audience` _string_ | Audience is the required token audience. | | | @@ -467,6 +468,27 @@ _Appears in:_ | `default` _[pkg.json.Any](#pkgjsonany)_ | Default is the fallback value if template expansion fails.
Type coercion is applied to match the declared Type. | | Schemaless: \{\}
| +#### vmcp.config.StaticBackendConfig + + + +StaticBackendConfig defines a pre-configured backend server for static mode. +This allows vMCP to operate without Kubernetes API access by embedding all backend +information directly in the configuration. + + + +_Appears in:_ +- [vmcp.config.Config](#vmcpconfigconfig) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `name` _string_ | Name is the backend identifier.
Must match the backend name from the MCPGroup for auth config resolution. | | Required: \{\}
| +| `url` _string_ | URL is the backend's MCP server base URL. | | Pattern: `^https?://`
Required: \{\}
| +| `transport` _string_ | Transport is the MCP transport protocol: "sse" or "streamable-http"
Only network transports supported by vMCP client are allowed. | | Enum: [sse streamable-http]
Required: \{\}
| +| `metadata` _object (keys:string, values:string)_ | Refer to Kubernetes API documentation for fields of `metadata`. | | | + + #### vmcp.config.StepErrorHandling diff --git a/pkg/vmcp/aggregator/discoverer.go b/pkg/vmcp/aggregator/discoverer.go index 012416f05c..13bc8f9a63 100644 --- a/pkg/vmcp/aggregator/discoverer.go +++ b/pkg/vmcp/aggregator/discoverer.go @@ -27,6 +27,8 @@ type backendDiscoverer struct { workloadsManager workloads.Discoverer groupsManager groups.Manager authConfig *config.OutgoingAuthConfig + staticBackends []config.StaticBackendConfig // Pre-configured backends for static mode + groupRef string // Group reference for static mode metadata } // NewUnifiedBackendDiscoverer creates a unified backend discoverer that works with both @@ -43,6 +45,23 @@ func NewUnifiedBackendDiscoverer( workloadsManager: workloadsManager, groupsManager: groupsManager, authConfig: authConfig, + staticBackends: nil, // Dynamic mode - discover backends at runtime + } +} + +// NewUnifiedBackendDiscovererWithStaticBackends creates a backend discoverer for static mode +// with pre-configured backends, eliminating the need for K8s API access. +func NewUnifiedBackendDiscovererWithStaticBackends( + staticBackends []config.StaticBackendConfig, + authConfig *config.OutgoingAuthConfig, + groupRef string, +) BackendDiscoverer { + return &backendDiscoverer{ + workloadsManager: nil, // Not needed in static mode + groupsManager: nil, // Not needed in static mode + authConfig: authConfig, + staticBackends: staticBackends, + groupRef: groupRef, } } @@ -95,9 +114,29 @@ func NewBackendDiscovererWithManager( // Discover finds all backend workloads in the specified group. // Returns all accessible backends with their health status marked based on workload status. // The groupRef is the group name (e.g., "engineering-team"). +// +// In static mode (when staticBackends are configured), this returns pre-configured backends +// without any K8s API access. In dynamic mode, it discovers backends at runtime. func (d *backendDiscoverer) Discover(ctx context.Context, groupRef string) ([]vmcp.Backend, error) { logger.Infof("Discovering backends in group %s", groupRef) + // Static mode: Use pre-configured backends if available + if len(d.staticBackends) > 0 { + logger.Infof("Using %d pre-configured static backends (no K8s API access)", len(d.staticBackends)) + return d.discoverFromStaticConfig() + } + + // If staticBackends was explicitly set (even if empty), but groupsManager is nil, + // this discoverer was created for static mode with an empty backend list. + // Return empty list instead of falling through to dynamic mode which would panic. + if d.staticBackends != nil && d.groupsManager == nil { + logger.Infof("Static mode with empty backend list, returning no backends") + return []vmcp.Backend{}, nil + } + + // Dynamic mode: Discover backends from K8s API at runtime + logger.Infof("Dynamic mode: discovering backends from K8s API") + // Verify that the group exists exists, err := d.groupsManager.Exists(ctx, groupRef) if err != nil { @@ -202,3 +241,40 @@ func (d *backendDiscoverer) applyAuthConfigToBackend(backend *vmcp.Backend, back } } } + +// discoverFromStaticConfig converts pre-configured static backends into vmcp.Backend objects +// for use in static mode where no K8s API access is available. +func (d *backendDiscoverer) discoverFromStaticConfig() ([]vmcp.Backend, error) { + backends := make([]vmcp.Backend, 0, len(d.staticBackends)) + + for _, staticBackend := range d.staticBackends { + backend := vmcp.Backend{ + ID: staticBackend.Name, + Name: staticBackend.Name, + BaseURL: staticBackend.URL, + TransportType: staticBackend.Transport, + HealthStatus: vmcp.BackendHealthy, // Assume healthy, actual health check happens later + Metadata: staticBackend.Metadata, + } + + // Apply auth configuration from OutgoingAuthConfig + d.applyAuthConfigToBackend(&backend, staticBackend.Name) + + // Set group metadata (reserved key, always overridden) + if backend.Metadata == nil { + backend.Metadata = make(map[string]string) + } + // Warn if user provided a conflicting group value + if existingGroup, exists := backend.Metadata["group"]; exists && existingGroup != d.groupRef { + logger.Warnf("Backend %s has user-provided group metadata '%s' which will be overridden with '%s'", + staticBackend.Name, existingGroup, d.groupRef) + } + backend.Metadata["group"] = d.groupRef + + backends = append(backends, backend) + logger.Infof("Loaded static backend: %s (url=%s, transport=%s)", + staticBackend.Name, staticBackend.URL, staticBackend.Transport) + } + + return backends, nil +} diff --git a/pkg/vmcp/aggregator/discoverer_test.go b/pkg/vmcp/aggregator/discoverer_test.go index 9f75120f9e..8bab6ec003 100644 --- a/pkg/vmcp/aggregator/discoverer_test.go +++ b/pkg/vmcp/aggregator/discoverer_test.go @@ -1168,3 +1168,140 @@ func TestBackendDiscoverer_applyAuthConfigToBackend(t *testing.T) { assert.Equal(t, "default-fallback-token", backend.AuthConfig.HeaderInjection.HeaderValue) }) } + +// TestStaticBackendDiscoverer_EmptyBackendList verifies that when a static discoverer +// is created with an empty backend list, it gracefully returns an empty list instead of +// panicking due to nil groupsManager (regression test for nil pointer dereference). +func TestStaticBackendDiscoverer_EmptyBackendList(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create a static discoverer with empty backend list (not nil, but zero length) + // This simulates the edge case where staticBackends was set but is empty + discoverer := NewUnifiedBackendDiscovererWithStaticBackends( + []config.StaticBackendConfig{}, // Empty slice, not nil + nil, // No auth config + "test-group", + ) + + // This should return empty list without panicking + // Previously would panic when falling through to dynamic mode with nil groupsManager + backends, err := discoverer.Discover(ctx, "test-group") + + require.NoError(t, err) + assert.Empty(t, backends) +} + +// TestStaticBackendDiscoverer_MetadataGroupOverride verifies that the "group" metadata key +// is always overridden with the groupRef value, even if user provides a different value. +func TestStaticBackendDiscoverer_MetadataGroupOverride(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + staticBackends []config.StaticBackendConfig + groupRef string + expectedGroupVals []string + }{ + { + name: "user-provided group metadata is overridden", + staticBackends: []config.StaticBackendConfig{ + { + Name: "backend1", + URL: "http://backend1:8080", + Transport: "sse", + Metadata: map[string]string{ + "group": "wrong-group", // User provided conflicting value + "env": "prod", + }, + }, + }, + groupRef: "correct-group", + expectedGroupVals: []string{"correct-group"}, + }, + { + name: "group metadata added when not present", + staticBackends: []config.StaticBackendConfig{ + { + Name: "backend2", + URL: "http://backend2:8080", + Transport: "streamable-http", + Metadata: map[string]string{ + "env": "dev", + }, + }, + }, + groupRef: "test-group", + expectedGroupVals: []string{"test-group"}, + }, + { + name: "group metadata added when metadata is nil", + staticBackends: []config.StaticBackendConfig{ + { + Name: "backend3", + URL: "http://backend3:8080", + Transport: "sse", + Metadata: nil, // No metadata at all + }, + }, + groupRef: "my-group", + expectedGroupVals: []string{"my-group"}, + }, + { + name: "multiple backends all get correct group", + staticBackends: []config.StaticBackendConfig{ + { + Name: "backend1", + URL: "http://backend1:8080", + Transport: "sse", + Metadata: map[string]string{"group": "wrong1"}, + }, + { + Name: "backend2", + URL: "http://backend2:8080", + Transport: "streamable-http", + Metadata: map[string]string{"env": "prod"}, + }, + }, + groupRef: "shared-group", + expectedGroupVals: []string{"shared-group", "shared-group"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + ctx := context.Background() + + discoverer := NewUnifiedBackendDiscovererWithStaticBackends( + tt.staticBackends, + nil, // No auth config needed for this test + tt.groupRef, + ) + + backends, err := discoverer.Discover(ctx, tt.groupRef) + require.NoError(t, err) + + // Verify we got the expected number of backends + assert.Len(t, backends, len(tt.expectedGroupVals)) + + // Verify each backend has the correct group metadata + for i, backend := range backends { + assert.NotNil(t, backend.Metadata, "Backend %d should have metadata", i) + assert.Equal(t, tt.expectedGroupVals[i], backend.Metadata["group"], + "Backend %d should have correct group metadata", i) + + // Verify other metadata is preserved + if tt.staticBackends[i].Metadata != nil { + for k, v := range tt.staticBackends[i].Metadata { + if k != "group" { + assert.Equal(t, v, backend.Metadata[k], + "Backend %d should preserve non-group metadata key %s", i, k) + } + } + } + } + }) + } +} diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index d1564e3c12..2f05902b4d 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -17,6 +17,19 @@ import ( authtypes "github.com/stacklok/toolhive/pkg/vmcp/auth/types" ) +// Transport type constants for static backend configuration. +// These define the allowed network transport protocols for vMCP backends in static mode. +const ( + // TransportSSE is the Server-Sent Events transport protocol. + TransportSSE = "sse" + // TransportStreamableHTTP is the streamable HTTP transport protocol. + TransportStreamableHTTP = "streamable-http" +) + +// StaticModeAllowedTransports lists all transport types allowed for static backend configuration. +// This must be kept in sync with the CRD enum validation in StaticBackendConfig.Transport. +var StaticModeAllowedTransports = []string{TransportSSE, TransportStreamableHTTP} + // Duration is a wrapper around time.Duration that marshals/unmarshals as a duration string. // This ensures duration values are serialized as "30s", "1m", etc. instead of nanosecond integers. // +kubebuilder:validation:Type=string @@ -80,6 +93,14 @@ type Config struct { // +kubebuilder:validation:Required Group string `json:"groupRef" yaml:"groupRef"` + // Backends defines pre-configured backend servers for static mode. + // When OutgoingAuth.Source is "inline", this field contains the full list of backend + // servers with their URLs and transport types, eliminating the need for K8s API access. + // When OutgoingAuth.Source is "discovered", this field is empty and backends are + // discovered at runtime via Kubernetes API. + // +optional + Backends []StaticBackendConfig `json:"backends,omitempty" yaml:"backends,omitempty"` + // IncomingAuth configures how clients authenticate to the virtual MCP server. // When using the Kubernetes operator, this is populated by the converter from // VirtualMCPServerSpec.IncomingAuth and any values set here will be superseded. @@ -161,6 +182,7 @@ type IncomingAuthConfig struct { // +gendoc type OIDCConfig struct { // Issuer is the OIDC issuer URL. + // +kubebuilder:validation:Pattern=`^https?://` Issuer string `json:"issuer" yaml:"issuer"` // ClientID is the OAuth client ID. @@ -203,6 +225,36 @@ type AuthzConfig struct { Policies []string `json:"policies,omitempty" yaml:"policies,omitempty"` } +// StaticBackendConfig defines a pre-configured backend server for static mode. +// This allows vMCP to operate without Kubernetes API access by embedding all backend +// information directly in the configuration. +// +gendoc +// +kubebuilder:object:generate=true +type StaticBackendConfig struct { + // Name is the backend identifier. + // Must match the backend name from the MCPGroup for auth config resolution. + // +kubebuilder:validation:Required + Name string `json:"name" yaml:"name"` + + // URL is the backend's MCP server base URL. + // +kubebuilder:validation:Required + // +kubebuilder:validation:Pattern=`^https?://` + URL string `json:"url" yaml:"url"` + + // Transport is the MCP transport protocol: "sse" or "streamable-http" + // Only network transports supported by vMCP client are allowed. + // +kubebuilder:validation:Enum=sse;streamable-http + // +kubebuilder:validation:Required + Transport string `json:"transport" yaml:"transport"` + + // Metadata is a custom key-value map for storing additional backend information + // such as labels, tags, or other arbitrary data (e.g., "env": "prod", "region": "us-east-1"). + // This is NOT Kubernetes ObjectMeta - it's a simple string map for user-defined metadata. + // Reserved keys: "group" is automatically set by vMCP and any user-provided value will be overridden. + // +optional + Metadata map[string]string `json:"metadata,omitempty" yaml:"metadata,omitempty"` +} + // OutgoingAuthConfig configures backend authentication. // // Note: When using the Kubernetes operator (VirtualMCPServer CRD), the diff --git a/pkg/vmcp/config/zz_generated.deepcopy.go b/pkg/vmcp/config/zz_generated.deepcopy.go index 6550ecddc7..9066fd45c5 100644 --- a/pkg/vmcp/config/zz_generated.deepcopy.go +++ b/pkg/vmcp/config/zz_generated.deepcopy.go @@ -138,6 +138,13 @@ func (in *CompositeToolRef) DeepCopy() *CompositeToolRef { // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Config) DeepCopyInto(out *Config) { *out = *in + if in.Backends != nil { + in, out := &in.Backends, &out.Backends + *out = make([]StaticBackendConfig, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } if in.IncomingAuth != nil { in, out := &in.IncomingAuth, &out.IncomingAuth *out = new(IncomingAuthConfig) @@ -430,6 +437,28 @@ func (in *OutputProperty) DeepCopy() *OutputProperty { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *StaticBackendConfig) DeepCopyInto(out *StaticBackendConfig) { + *out = *in + if in.Metadata != nil { + in, out := &in.Metadata, &out.Metadata + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new StaticBackendConfig. +func (in *StaticBackendConfig) DeepCopy() *StaticBackendConfig { + if in == nil { + return nil + } + out := new(StaticBackendConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *StepErrorHandling) DeepCopyInto(out *StepErrorHandling) { *out = *in diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go index 4d82eb0dca..8594b89c29 100644 --- a/pkg/vmcp/discovery/middleware_test.go +++ b/pkg/vmcp/discovery/middleware_test.go @@ -28,6 +28,50 @@ func createTestSessionManager(t *testing.T) *transportsession.Manager { return sessionMgr } +// unorderedBackendsMatcher is a gomock matcher that compares backend slices without caring about order. +// This is needed because ImmutableRegistry.List() iterates over a map which doesn't guarantee order. +type unorderedBackendsMatcher struct { + expected []vmcp.Backend +} + +func (m unorderedBackendsMatcher) Matches(x any) bool { + actual, ok := x.([]vmcp.Backend) + if !ok { + return false + } + if len(actual) != len(m.expected) { + return false + } + + // Create maps for comparison + expectedMap := make(map[string]vmcp.Backend) + for _, b := range m.expected { + expectedMap[b.ID] = b + } + + actualMap := make(map[string]vmcp.Backend) + for _, b := range actual { + actualMap[b.ID] = b + } + + // Check all expected backends are present + for id, expectedBackend := range expectedMap { + actualBackend, found := actualMap[id] + if !found { + return false + } + if expectedBackend.ID != actualBackend.ID || expectedBackend.Name != actualBackend.Name { + return false + } + } + + return true +} + +func (unorderedBackendsMatcher) String() string { + return "matches backends regardless of order" +} + func TestMiddleware_InitializeRequest(t *testing.T) { t.Parallel() @@ -67,7 +111,7 @@ func TestMiddleware_InitializeRequest(t *testing.T) { // Expect discovery to be called for initialize request (no session ID) mockMgr.EXPECT(). - Discover(gomock.Any(), backends). + Discover(gomock.Any(), unorderedBackendsMatcher{backends}). Return(expectedCaps, nil) // Create a test handler that verifies capabilities are in context @@ -303,17 +347,7 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) { // Use Do to capture and verify backends separately, since order may vary mockMgr.EXPECT(). - Discover(gomock.Any(), gomock.Any()). - Do(func(_ context.Context, actualBackends []vmcp.Backend) { - // Verify that we got the expected backends regardless of order - assert.Len(t, actualBackends, 2) - backendIDs := make(map[string]bool) - for _, b := range actualBackends { - backendIDs[b.ID] = true - } - assert.True(t, backendIDs["backend1"], "backend1 should be present") - assert.True(t, backendIDs["backend2"], "backend2 should be present") - }). + Discover(gomock.Any(), unorderedBackendsMatcher{backends}). Return(expectedCaps, nil) // Create handler that inspects context in detail diff --git a/pkg/vmcp/workloads/k8s.go b/pkg/vmcp/workloads/k8s.go index fdcc751bac..24b081da81 100644 --- a/pkg/vmcp/workloads/k8s.go +++ b/pkg/vmcp/workloads/k8s.go @@ -199,6 +199,9 @@ func (d *k8sDiscoverer) mcpServerToBackend(ctx context.Context, mcpServer *mcpv1 // Generate URL from status or reconstruct from spec url := mcpServer.Status.URL if url == "" { + // Use ProxyPort (not McpPort) because it's the externally accessible port + // that the egress proxy listens on. This is what vMCP connects to. + // The McpPort is only for internal container-to-container communication. port := int(mcpServer.Spec.ProxyPort) if port == 0 { port = int(mcpServer.Spec.Port) // Fallback to deprecated Port field diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_discovered_mode_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_discovered_mode_test.go index b8e3e89f19..3ab460bce7 100644 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_discovered_mode_test.go +++ b/test/e2e/thv-operator/virtualmcp/virtualmcp_discovered_mode_test.go @@ -2,7 +2,9 @@ package virtualmcp import ( "context" + "encoding/json" "fmt" + "io" "net/http" "strings" "time" @@ -19,6 +21,13 @@ import ( "github.com/stacklok/toolhive/test/e2e/images" ) +// ReadinessResponse represents the /readyz endpoint response +type ReadinessResponse struct { + Status string `json:"status"` + Mode string `json:"mode"` + Reason string `json:"reason,omitempty"` +} + var _ = Describe("VirtualMCPServer Discovered Mode", Ordered, func() { var ( testNamespace = "default" @@ -402,4 +411,374 @@ var _ = Describe("VirtualMCPServer Discovered Mode", Ordered, func() { Expect(backendNames).To(ContainElements(backend1Name, backend2Name)) }) }) + + Context("when dynamically adding a new backend", func() { + var ( + backend3Name = "backend-dynamic-fetch" + initialToolCount int + ) + + AfterAll(func() { + // Clean up the dynamic backend + backend3 := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: backend3Name, + Namespace: testNamespace, + }, + } + _ = k8sClient.Delete(ctx, backend3) + }) + + It("should record initial tool count", func() { + By("Creating MCP client to get initial tool count") + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) + mcpClient, err := client.NewStreamableHttpClient(serverURL) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + Eventually(func() error { + err = mcpClient.Start(testCtx) + if err != nil { + return err + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "toolhive-e2e-initial-count", + Version: "1.0.0", + } + + _, err = mcpClient.Initialize(testCtx, initRequest) + return err + }, 30*time.Second, 5*time.Second).Should(Succeed()) + + var tools *mcp.ListToolsResult + Eventually(func() error { + var err error + tools, err = mcpClient.ListTools(testCtx, mcp.ListToolsRequest{}) + return err + }, 30*time.Second, 2*time.Second).Should(Succeed()) + + initialToolCount = len(tools.Tools) + GinkgoWriter.Printf("Initial tool count: %d\n", initialToolCount) + }) + + It("should detect new backend and update tool list", func() { + By("Adding third backend MCPServer") + backend3 := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: backend3Name, + Namespace: testNamespace, + }, + Spec: mcpv1alpha1.MCPServerSpec{ + GroupRef: mcpGroupName, + Image: images.GofetchServerImage, + Transport: "streamable-http", + ProxyPort: 8080, + McpPort: 8080, + }, + } + Expect(k8sClient.Create(ctx, backend3)).To(Succeed()) + + By("Waiting for new backend to be ready") + Eventually(func() error { + server := &mcpv1alpha1.MCPServer{} + err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backend3Name, + Namespace: testNamespace, + }, server) + if err != nil { + return err + } + if server.Status.Phase != mcpv1alpha1.MCPServerPhaseRunning { + return fmt.Errorf("backend not ready, phase: %s", server.Status.Phase) + } + return nil + }, timeout, pollingInterval).Should(Succeed()) + + By("Verifying group now has three backends") + Eventually(func() int { + backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) + if err != nil { + return 0 + } + return len(backends) + }, 30*time.Second, 2*time.Second).Should(Equal(3)) + + By("Verifying tool count increased with new session") + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) + + Eventually(func() error { + mcpClient, err := client.NewStreamableHttpClient(serverURL) + if err != nil { + return err + } + defer mcpClient.Close() + + testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + err = mcpClient.Start(testCtx) + if err != nil { + return err + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "toolhive-e2e-after-add", + Version: "1.0.0", + } + + _, err = mcpClient.Initialize(testCtx, initRequest) + if err != nil { + return err + } + + tools, err := mcpClient.ListTools(testCtx, mcp.ListToolsRequest{}) + if err != nil { + return err + } + + if len(tools.Tools) <= initialToolCount { + return fmt.Errorf("expected more tools, got %d (was %d)", len(tools.Tools), initialToolCount) + } + return nil + }, 1*time.Minute, 10*time.Second).Should(Succeed()) + }) + }) + + Context("when dynamically removing a backend", func() { + It("should detect backend removal and update tool list", func() { + By("Getting current tool count") + serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) + mcpClient, err := client.NewStreamableHttpClient(serverURL) + Expect(err).ToNot(HaveOccurred()) + defer mcpClient.Close() + + testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + Eventually(func() error { + err = mcpClient.Start(testCtx) + if err != nil { + return err + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "toolhive-e2e-before-remove", + Version: "1.0.0", + } + + _, err = mcpClient.Initialize(testCtx, initRequest) + return err + }, 30*time.Second, 5*time.Second).Should(Succeed()) + + var toolsBeforeRemoval *mcp.ListToolsResult + Eventually(func() error { + var err error + toolsBeforeRemoval, err = mcpClient.ListTools(testCtx, mcp.ListToolsRequest{}) + return err + }, 30*time.Second, 2*time.Second).Should(Succeed()) + + toolCountBefore := len(toolsBeforeRemoval.Tools) + GinkgoWriter.Printf("Before removal: %d tools\n", toolCountBefore) + + By("Removing backend2 (osv)") + backend2 := &mcpv1alpha1.MCPServer{ + ObjectMeta: metav1.ObjectMeta{ + Name: backend2Name, + Namespace: testNamespace, + }, + } + Expect(k8sClient.Delete(ctx, backend2)).To(Succeed()) + + By("Waiting for backend deletion") + Eventually(func() bool { + server := &mcpv1alpha1.MCPServer{} + err := k8sClient.Get(ctx, types.NamespacedName{ + Name: backend2Name, + Namespace: testNamespace, + }, server) + return err != nil + }, timeout, pollingInterval).Should(BeTrue()) + + By("Verifying group now has fewer backends") + Eventually(func() int { + backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) + if err != nil { + return -1 + } + return len(backends) + }, 30*time.Second, 2*time.Second).Should(BeNumerically("<", 3)) + + By("Verifying tool count decreased with new session") + Eventually(func() error { + mcpClient2, err := client.NewStreamableHttpClient(serverURL) + if err != nil { + return err + } + defer mcpClient2.Close() + + testCtx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel2() + + err = mcpClient2.Start(testCtx2) + if err != nil { + return err + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "toolhive-e2e-after-remove", + Version: "1.0.0", + } + + _, err = mcpClient2.Initialize(testCtx2, initRequest) + if err != nil { + return err + } + + tools, err := mcpClient2.ListTools(testCtx2, mcp.ListToolsRequest{}) + if err != nil { + return err + } + + if len(tools.Tools) >= toolCountBefore { + return fmt.Errorf("expected fewer tools, got %d (was %d)", len(tools.Tools), toolCountBefore) + } + return nil + }, 1*time.Minute, 10*time.Second).Should(Succeed()) + }) + }) + + Context("when testing health and readiness endpoints", func() { + It("should expose /health endpoint that always returns 200", func() { + vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) + + By("Checking /health endpoint") + resp, err := http.Get(vmcpURL + "/health") + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + var health map[string]string + err = json.NewDecoder(resp.Body).Decode(&health) + Expect(err).NotTo(HaveOccurred()) + Expect(health["status"]).To(Equal("ok")) + }) + + It("should expose /readyz endpoint", func() { + vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) + + By("Checking /readyz endpoint is accessible") + resp, err := http.Get(vmcpURL + "/readyz") + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + Fail(fmt.Sprintf("unexpected status code: %d, body: %s", resp.StatusCode, string(body))) + } + + By("Parsing readiness response") + var readiness ReadinessResponse + err = json.NewDecoder(resp.Body).Decode(&readiness) + Expect(err).NotTo(HaveOccurred()) + + By("Verifying readiness status") + Expect(readiness.Status).To(Equal("ready"), "Status should be ready") + }) + + It("should distinguish between /health and /readyz", func() { + vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) + + By("Getting /health response") + healthResp, err := http.Get(vmcpURL + "/health") + Expect(err).NotTo(HaveOccurred()) + defer healthResp.Body.Close() + + By("Getting /readyz response") + readyResp, err := http.Get(vmcpURL + "/readyz") + Expect(err).NotTo(HaveOccurred()) + defer readyResp.Body.Close() + + // Both should return 200 when ready + Expect(healthResp.StatusCode).To(Equal(http.StatusOK)) + Expect(readyResp.StatusCode).To(Equal(http.StatusOK)) + + // Parse both responses + var health map[string]string + err = json.NewDecoder(healthResp.Body).Decode(&health) + Expect(err).NotTo(HaveOccurred()) + + var readiness ReadinessResponse + err = json.NewDecoder(readyResp.Body).Decode(&readiness) + Expect(err).NotTo(HaveOccurred()) + + // Health is simple status + Expect(health).To(HaveKey("status")) + Expect(health).NotTo(HaveKey("mode")) + + // Readiness includes status + Expect(readiness.Status).To(Equal("ready")) + }) + }) + + Context("when testing status endpoint", func() { + It("should expose /status endpoint with group reference", func() { + vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) + + By("Checking /status endpoint") + resp, err := http.Get(vmcpURL + "/status") + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + Expect(resp.StatusCode).To(Equal(http.StatusOK)) + + var status map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&status) + Expect(err).NotTo(HaveOccurred()) + + By("Verifying group_ref is present") + Expect(status).To(HaveKey("group_ref")) + groupRef, ok := status["group_ref"].(string) + Expect(ok).To(BeTrue()) + Expect(groupRef).To(ContainSubstring(mcpGroupName)) + }) + + It("should list discovered backends in status", func() { + vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) + + By("Getting /status response") + resp, err := http.Get(vmcpURL + "/status") + Expect(err).NotTo(HaveOccurred()) + defer resp.Body.Close() + + var status map[string]interface{} + err = json.NewDecoder(resp.Body).Decode(&status) + Expect(err).NotTo(HaveOccurred()) + + By("Verifying backends are listed") + Expect(status).To(HaveKey("backends")) + backends, ok := status["backends"].([]interface{}) + Expect(ok).To(BeTrue()) + Expect(backends).NotTo(BeEmpty(), "Should have at least one backend") + + // Verify backend structure + backend, ok := backends[0].(map[string]interface{}) + Expect(ok).To(BeTrue(), "backend should be a map") + Expect(backend).To(HaveKey("name")) + Expect(backend).To(HaveKey("health")) + Expect(backend).To(HaveKey("transport")) + }) + }) }) diff --git a/test/e2e/thv-operator/virtualmcp/virtualmcp_lifecycle_test.go b/test/e2e/thv-operator/virtualmcp/virtualmcp_lifecycle_test.go deleted file mode 100644 index 47a6300e6d..0000000000 --- a/test/e2e/thv-operator/virtualmcp/virtualmcp_lifecycle_test.go +++ /dev/null @@ -1,1245 +0,0 @@ -package virtualmcp - -import ( - "context" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/mark3labs/mcp-go/client" - "github.com/mark3labs/mcp-go/mcp" - . "github.com/onsi/ginkgo/v2" - . "github.com/onsi/gomega" - corev1 "k8s.io/api/core/v1" - "k8s.io/apimachinery/pkg/api/errors" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/types" - ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" - - mcpv1alpha1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1alpha1" - vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config" - "github.com/stacklok/toolhive/test/e2e/images" -) - -// NOTE: These tests verify DynamicRegistry functionality with full operator integration. -// The vMCP server now uses DynamicRegistry in Kubernetes mode and supports dynamic -// backend discovery via BackendWatcher. New sessions will see updated backends when -// they are added/removed from the MCPGroup. Existing sessions retain their original -// capability snapshot. -// -// Implementation status: DynamicRegistry is fully integrated with BackendReconciler that -// watches MCPServer/MCPRemoteProxy resources and updates the registry in real-time. - -// getBackendName safely extracts the backend name from a status response interface. -func getBackendName(b interface{}) string { - if backend, ok := b.(map[string]interface{}); ok { - if name, ok := backend["name"].(string); ok { - return name - } - } - return "" -} - -var _ = Describe("VirtualMCPServer Lifecycle - Dynamic Backend Discovery", Ordered, func() { - var ( - testNamespace = "default" - mcpGroupName = "test-lifecycle-group" - vmcpServerName = "test-vmcp-lifecycle" - backend1Name = "backend-lifecycle-fetch" - backend2Name = "backend-lifecycle-osv" - backend3Name = "backend-lifecycle-dynamic" // Backend added dynamically - timeout = 3 * time.Minute - pollingInterval = 1 * time.Second - vmcpNodePort int32 - ) - - BeforeAll(func() { - By("Creating MCPGroup") - CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, - "Test MCP Group for VirtualMCP lifecycle E2E tests", timeout, pollingInterval) - - By("Creating initial backend MCPServer - fetch (streamable-http)") - backend1 := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backend1Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.GofetchServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, backend1)).To(Succeed()) - - By("Waiting for initial backend MCPServer to be ready") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: backend1Name, - Namespace: testNamespace, - }, server) - if err != nil { - return fmt.Errorf("failed to get server: %w", err) - } - - if server.Status.Phase == mcpv1alpha1.MCPServerPhaseRunning { - return nil - } - return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) - }, timeout, pollingInterval).Should(Succeed(), "Initial backend should be ready") - - By("Creating VirtualMCPServer in discovered mode") - vmcpServer := &mcpv1alpha1.VirtualMCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: vmcpServerName, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.VirtualMCPServerSpec{ - Config: vmcpconfig.Config{ - Group: mcpGroupName, - Aggregation: &vmcpconfig.AggregationConfig{ - ConflictResolution: "prefix", - }, - }, - IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ - Type: "anonymous", - }, - ServiceType: "NodePort", - }, - } - Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) - - By("Waiting for VirtualMCPServer to be ready") - WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) - - By("Waiting for VirtualMCPServer to discover backends") - WaitForCondition(ctx, k8sClient, vmcpServerName, testNamespace, "BackendsDiscovered", "True", timeout, pollingInterval) - - By("Getting NodePort for VirtualMCPServer") - vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) - - By(fmt.Sprintf("VirtualMCPServer accessible at http://localhost:%d", vmcpNodePort)) - - By("Waiting for VirtualMCPServer to be accessible") - Eventually(func() error { - httpClient := &http.Client{Timeout: 5 * time.Second} - url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) - resp, err := httpClient.Get(url) - if err != nil { - return fmt.Errorf("health check failed: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - return nil - }, 30*time.Second, 2*time.Second).Should(Succeed(), "VirtualMCPServer health endpoint should be accessible") - }) - - AfterAll(func() { - By("Cleaning up VirtualMCPServer") - vmcpServer := &mcpv1alpha1.VirtualMCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: vmcpServerName, - Namespace: testNamespace, - }, - } - if err := k8sClient.Delete(ctx, vmcpServer); err != nil { - GinkgoWriter.Printf("Warning: failed to delete VirtualMCPServer: %v\n", err) - } - - By("Cleaning up all backend MCPServers") - for _, backendName := range []string{backend1Name, backend2Name, backend3Name} { - backend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backendName, - Namespace: testNamespace, - }, - } - if err := k8sClient.Delete(ctx, backend); err != nil { - GinkgoWriter.Printf("Warning: failed to delete backend %s: %v\n", backendName, err) - } - } - - By("Cleaning up MCPGroup") - mcpGroup := &mcpv1alpha1.MCPGroup{ - ObjectMeta: metav1.ObjectMeta{ - Name: mcpGroupName, - Namespace: testNamespace, - }, - } - if err := k8sClient.Delete(ctx, mcpGroup); err != nil { - GinkgoWriter.Printf("Warning: failed to delete MCPGroup: %v\n", err) - } - }) - - var initialToolCount int - - Context("when testing initial backend discovery", func() { - It("should discover tools from initial backend", func() { - By("Creating MCP client for VirtualMCPServer") - serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) - mcpClient, err := client.NewStreamableHttpClient(serverURL) - Expect(err).ToNot(HaveOccurred()) - defer mcpClient.Close() - - By("Starting transport and initializing connection") - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - Eventually(func() error { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - - err = mcpClient.Start(initCtx) - if err != nil { - return fmt.Errorf("failed to start transport: %w", err) - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "toolhive-e2e-lifecycle-test", - Version: "1.0.0", - } - - _, err = mcpClient.Initialize(initCtx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize: %w", err) - } - - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed(), "MCP client should initialize successfully") - - By("Listing tools from VirtualMCPServer") - var initialTools *mcp.ListToolsResult - Eventually(func() error { - listRequest := mcp.ListToolsRequest{} - var err error - initialTools, err = mcpClient.ListTools(ctx, listRequest) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - if len(initialTools.Tools) == 0 { - return fmt.Errorf("no tools returned") - } - return nil - }, 30*time.Second, 2*time.Second).Should(Succeed(), "Should be able to list tools") - - initialToolCount = len(initialTools.Tools) - By(fmt.Sprintf("Initial state: VirtualMCPServer has %d tools", initialToolCount)) - for _, tool := range initialTools.Tools { - GinkgoWriter.Printf(" Initial tool: %s - %s\n", tool.Name, tool.Description) - } - - // Verify we have at least one tool from the initial backend - Expect(initialTools.Tools).ToNot(BeEmpty(), "VirtualMCPServer should have tools from initial backend") - }) - - It("should have exactly one backend in the group", func() { - backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) - Expect(err).ToNot(HaveOccurred()) - Expect(backends).To(HaveLen(1), "Should have exactly one backend initially") - Expect(backends[0].Name).To(Equal(backend1Name)) - }) - }) - - Context("when dynamically adding a new backend", func() { - It("should detect the new backend and update tool list", func() { - By("Adding second backend MCPServer - osv (streamable-http)") - backend2 := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backend2Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.OSVMCPServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, backend2)).To(Succeed()) - - By("Waiting for new backend MCPServer to be ready") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: backend2Name, - Namespace: testNamespace, - }, server) - if err != nil { - return fmt.Errorf("failed to get server: %w", err) - } - - if server.Status.Phase == mcpv1alpha1.MCPServerPhaseRunning { - return nil - } - return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) - }, timeout, pollingInterval).Should(Succeed(), "New backend should be ready") - - By("Verifying the group now has two backends") - Eventually(func() int { - backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) - if err != nil { - return 0 - } - return len(backends) - }, 30*time.Second, 2*time.Second).Should(Equal(2), "Should have two backends after adding") - - By("Waiting for VirtualMCPServer to reconcile and discover tools from both backends") - // Use Eventually to wait for the VirtualMCPServer to: - // 1. Detect the new backend in the group via operator reconciliation - // 2. Update the DynamicRegistry (which increments version) - // 3. Invalidate cached capabilities - // 4. Rediscover capabilities from both backends - serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) - - var updatedTools *mcp.ListToolsResult - Eventually(func() error { - // Create a fresh client for each attempt to ensure we're not hitting stale cache - mcpClient, err := client.NewStreamableHttpClient(serverURL) - if err != nil { - return fmt.Errorf("failed to create client: %w", err) - } - defer mcpClient.Close() - - testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Start and initialize - err = mcpClient.Start(testCtx) - if err != nil { - return fmt.Errorf("failed to start transport: %w", err) - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "toolhive-e2e-lifecycle-test-add", - Version: "1.0.0", - } - - _, err = mcpClient.Initialize(testCtx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize: %w", err) - } - - // List tools - listRequest := mcp.ListToolsRequest{} - updatedTools, err = mcpClient.ListTools(testCtx, listRequest) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - - currentToolCount := len(updatedTools.Tools) - - // Log current state for debugging - if currentToolCount > 0 { - GinkgoWriter.Printf("Attempt: %d tools found (initial: %d)\n", currentToolCount, initialToolCount) - for _, tool := range updatedTools.Tools { - GinkgoWriter.Printf(" - %s\n", tool.Name) - } - } - - // Should have more tools now (from both backends) - // Check if tool count increased from initial state - if currentToolCount <= initialToolCount { - return fmt.Errorf("expected more tools after adding backend, got %d (initial: %d)", currentToolCount, initialToolCount) - } - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed(), "Should see more tools after adding second backend") - - By(fmt.Sprintf("After adding backend: VirtualMCPServer now has %d tools", len(updatedTools.Tools))) - for _, tool := range updatedTools.Tools { - GinkgoWriter.Printf(" Updated tool: %s - %s\n", tool.Name, tool.Description) - } - }) - }) - - Context("when dynamically removing a backend", func() { - It("should detect backend removal and update tool list", func() { - By("Getting current tool count") - serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) - mcpClient, err := client.NewStreamableHttpClient(serverURL) - Expect(err).ToNot(HaveOccurred()) - defer mcpClient.Close() - - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) - defer cancel() - - Eventually(func() error { - initCtx, initCancel := context.WithTimeout(context.Background(), 10*time.Second) - defer initCancel() - - err = mcpClient.Start(initCtx) - if err != nil { - return fmt.Errorf("failed to start transport: %w", err) - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "toolhive-e2e-lifecycle-test-before-remove", - Version: "1.0.0", - } - - _, err = mcpClient.Initialize(initCtx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize: %w", err) - } - - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed()) - - var toolsBeforeRemoval *mcp.ListToolsResult - Eventually(func() error { - listRequest := mcp.ListToolsRequest{} - var err error - toolsBeforeRemoval, err = mcpClient.ListTools(ctx, listRequest) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - return nil - }, 30*time.Second, 2*time.Second).Should(Succeed()) - - toolCountBefore := len(toolsBeforeRemoval.Tools) - By(fmt.Sprintf("Before removal: %d tools", toolCountBefore)) - - By("Removing the second backend (osv)") - backend2 := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backend2Name, - Namespace: testNamespace, - }, - } - Expect(k8sClient.Delete(ctx, backend2)).To(Succeed()) - - By("Waiting for backend deletion to complete") - Eventually(func() bool { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: backend2Name, - Namespace: testNamespace, - }, server) - return err != nil // Should fail to get when deleted - }, timeout, pollingInterval).Should(BeTrue(), "Backend should be deleted") - - By("Verifying the group now has one backend") - Eventually(func() int { - backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) - if err != nil { - return -1 - } - return len(backends) - }, 30*time.Second, 2*time.Second).Should(Equal(1), "Should have one backend after removal") - - By("Waiting for VirtualMCPServer to detect backend removal and update tool list") - var toolsAfterRemoval *mcp.ListToolsResult - Eventually(func() error { - // Create a fresh client for each attempt - mcpClient2, err := client.NewStreamableHttpClient(serverURL) - if err != nil { - return fmt.Errorf("failed to create client: %w", err) - } - defer mcpClient2.Close() - - testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Start and initialize - err = mcpClient2.Start(testCtx) - if err != nil { - return fmt.Errorf("failed to start transport: %w", err) - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "toolhive-e2e-lifecycle-test-after-remove", - Version: "1.0.0", - } - - _, err = mcpClient2.Initialize(testCtx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize: %w", err) - } - - // List tools - listRequest := mcp.ListToolsRequest{} - toolsAfterRemoval, err = mcpClient2.ListTools(testCtx, listRequest) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - - toolCountAfter := len(toolsAfterRemoval.Tools) - - // Verify tool count decreased (tools from removed backend are gone) - if toolCountAfter >= toolCountBefore { - return fmt.Errorf("expected fewer tools after removal, got %d (was %d)", toolCountAfter, toolCountBefore) - } - - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed(), "Should have fewer tools after backend removal") - - By(fmt.Sprintf("After removal: %d tools (was %d)", len(toolsAfterRemoval.Tools), toolCountBefore)) - - By("Verifying tools from removed backend are no longer present") - for _, tool := range toolsAfterRemoval.Tools { - GinkgoWriter.Printf(" Remaining tool: %s - %s\n", tool.Name, tool.Description) - // Tools from osv backend should not be present - Expect(strings.Contains(strings.ToLower(tool.Name), "osv")).To(BeFalse(), - "Tools from removed osv backend should not be present") - } - }) - }) - - Context("when testing cache invalidation", func() { - It("should invalidate cache when backends change", func() { - By("Adding a third backend to trigger cache invalidation") - backend3 := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backend3Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.GofetchServerImage, // Use fetch image for simplicity - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, backend3)).To(Succeed()) - - By("Waiting for new backend to be ready") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: backend3Name, - Namespace: testNamespace, - }, server) - if err != nil { - return fmt.Errorf("failed to get server: %w", err) - } - - if server.Status.Phase == mcpv1alpha1.MCPServerPhaseRunning { - return nil - } - return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) - }, timeout, pollingInterval).Should(Succeed()) - - By("Verifying tool list is updated (cache was invalidated)") - serverURL := fmt.Sprintf("http://localhost:%d/mcp", vmcpNodePort) - - var tools *mcp.ListToolsResult - Eventually(func() error { - // Create a fresh client for each attempt - mcpClient, err := client.NewStreamableHttpClient(serverURL) - if err != nil { - return fmt.Errorf("failed to create client: %w", err) - } - defer mcpClient.Close() - - testCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - // Start and initialize - err = mcpClient.Start(testCtx) - if err != nil { - return fmt.Errorf("failed to start transport: %w", err) - } - - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "toolhive-e2e-lifecycle-test-cache", - Version: "1.0.0", - } - - _, err = mcpClient.Initialize(testCtx, initRequest) - if err != nil { - return fmt.Errorf("failed to initialize: %w", err) - } - - // List tools - listRequest := mcp.ListToolsRequest{} - tools, err = mcpClient.ListTools(testCtx, listRequest) - if err != nil { - return fmt.Errorf("failed to list tools: %w", err) - } - - // Should now have tools from 2 backends (backend1 and backend3) - if len(tools.Tools) < 1 { - return fmt.Errorf("expected tools from active backends, got %d", len(tools.Tools)) - } - - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed(), "Cache should be invalidated and show updated tools") - - By(fmt.Sprintf("After cache invalidation: VirtualMCPServer has %d tools from active backends", len(tools.Tools))) - - By("Verifying backends in the group") - backends, err := GetMCPGroupBackends(ctx, k8sClient, mcpGroupName, testNamespace) - Expect(err).ToNot(HaveOccurred()) - Expect(backends).To(HaveLen(2), "Should have two backends after adding third backend") - - backendNames := make([]string, len(backends)) - for i, backend := range backends { - backendNames[i] = backend.Name - } - Expect(backendNames).To(ContainElements(backend1Name, backend3Name)) - Expect(backendNames).ToNot(ContainElement(backend2Name), "Removed backend should not be present") - }) - }) -}) - -// ReadinessResponse represents the /readyz endpoint response -type ReadinessResponse struct { - Status string `json:"status"` - Mode string `json:"mode"` - Reason string `json:"reason,omitempty"` -} - -// VirtualMCPServer K8s Manager Infrastructure Tests -// These tests verify the K8s manager integration that was implemented as part of THV-2884. -// This includes BackendWatcher with BackendReconciler for dynamic backend discovery, -// manager creation, readiness probes, cache sync, and endpoint behavior. -var _ = Describe("VirtualMCPServer K8s Manager Infrastructure", Ordered, func() { - var ( - testNamespace = "default" - mcpGroupName = "test-k8s-manager-infra-group" - vmcpServerName = "test-vmcp-k8s-manager-infra" - backendName = "backend-k8s-manager-infra-fetch" - timeout = 3 * time.Minute - pollingInterval = 2 * time.Second - vmcpNodePort int32 - ) - - BeforeAll(func() { - By("Creating MCPGroup for K8s manager infrastructure tests") - CreateMCPGroupAndWait(ctx, k8sClient, mcpGroupName, testNamespace, - "Test MCP Group for K8s manager infrastructure E2E tests", timeout, pollingInterval) - - By("Creating backend MCPServer") - backend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backendName, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.GofetchServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, backend)).To(Succeed()) - - By("Waiting for backend MCPServer to be ready") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: backendName, - Namespace: testNamespace, - }, server) - if err != nil { - return fmt.Errorf("failed to get server: %w", err) - } - - if server.Status.Phase == mcpv1alpha1.MCPServerPhaseRunning { - return nil - } - return fmt.Errorf("backend not ready yet, phase: %s", server.Status.Phase) - }, timeout, pollingInterval).Should(Succeed(), "Backend should be ready") - - By("Creating VirtualMCPServer with discovered auth source (dynamic mode)") - vmcpServer := &mcpv1alpha1.VirtualMCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: vmcpServerName, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.VirtualMCPServerSpec{ - Config: vmcpconfig.Config{ - Group: mcpGroupName, - Aggregation: &vmcpconfig.AggregationConfig{ - ConflictResolution: "prefix", - }, - }, - IncomingAuth: &mcpv1alpha1.IncomingAuthConfig{ - Type: "anonymous", - }, - OutgoingAuth: &mcpv1alpha1.OutgoingAuthConfig{ - Source: "discovered", // This triggers K8s manager creation - }, - ServiceType: "NodePort", - }, - } - Expect(k8sClient.Create(ctx, vmcpServer)).To(Succeed()) - - By("Waiting for VirtualMCPServer to be ready") - WaitForVirtualMCPServerReady(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) - - By("Getting NodePort for VirtualMCPServer") - vmcpNodePort = GetVMCPNodePort(ctx, k8sClient, vmcpServerName, testNamespace, timeout, pollingInterval) - - By(fmt.Sprintf("VirtualMCPServer is ready on NodePort: %d", vmcpNodePort)) - - By("Waiting for VirtualMCPServer to be accessible") - Eventually(func() error { - httpClient := &http.Client{Timeout: 5 * time.Second} - url := fmt.Sprintf("http://localhost:%d/health", vmcpNodePort) - resp, err := httpClient.Get(url) - if err != nil { - return fmt.Errorf("health check failed: %w", err) - } - defer resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - return nil - }, 30*time.Second, 2*time.Second).Should(Succeed(), "VirtualMCPServer health endpoint should be accessible") - }) - - AfterAll(func() { - By("Cleaning up VirtualMCPServer") - vmcpServer := &mcpv1alpha1.VirtualMCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: vmcpServerName, - Namespace: testNamespace, - }, - } - _ = k8sClient.Delete(ctx, vmcpServer) - - By("Cleaning up backend MCPServer") - backend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: backendName, - Namespace: testNamespace, - }, - } - _ = k8sClient.Delete(ctx, backend) - - By("Cleaning up MCPGroup") - group := &mcpv1alpha1.MCPGroup{ - ObjectMeta: metav1.ObjectMeta{ - Name: mcpGroupName, - Namespace: testNamespace, - }, - } - _ = k8sClient.Delete(ctx, group) - }) - - Context("Readiness Probe Integration", func() { - It("should expose /readyz endpoint", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Checking /readyz endpoint is accessible") - Eventually(func() error { - resp, err := http.Get(vmcpURL + "/readyz") - if err != nil { - return fmt.Errorf("failed to connect to /readyz: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) - } - - return nil - }, 2*time.Minute, 5*time.Second).Should(Succeed(), "/readyz should return 200 OK") - }) - - It("should return dynamic mode status", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Getting /readyz response") - resp, err := http.Get(vmcpURL + "/readyz") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - Expect(resp.StatusCode).To(Equal(http.StatusOK)) - - By("Parsing readiness response") - var readiness ReadinessResponse - err = json.NewDecoder(resp.Body).Decode(&readiness) - Expect(err).NotTo(HaveOccurred()) - - By("Verifying dynamic mode is enabled") - Expect(readiness.Status).To(Equal("ready"), "Status should be ready") - Expect(readiness.Mode).To(Equal("dynamic"), "Mode should be dynamic since outgoingAuth.source is 'discovered'") - }) - - It("should indicate cache sync in dynamic mode", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Verifying cache is synced") - resp, err := http.Get(vmcpURL + "/readyz") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - var readiness ReadinessResponse - err = json.NewDecoder(resp.Body).Decode(&readiness) - Expect(err).NotTo(HaveOccurred()) - - // In dynamic mode with synced cache, status should be "ready" - Expect(readiness.Status).To(Equal("ready")) - Expect(readiness.Mode).To(Equal("dynamic")) - // Reason should be empty when ready - Expect(readiness.Reason).To(BeEmpty()) - }) - }) - - Context("K8s Manager Lifecycle", func() { - It("should start with K8s manager running", func() { - By("Verifying pod is running") - Eventually(func() error { - pods := &corev1.PodList{} - err := k8sClient.List(ctx, pods, - ctrlclient.InNamespace(testNamespace), - ctrlclient.MatchingLabels{"app.kubernetes.io/instance": vmcpServerName}) - if err != nil { - return fmt.Errorf("failed to list pods: %w", err) - } - - if len(pods.Items) == 0 { - return fmt.Errorf("no pods found") - } - - pod := pods.Items[0] - if pod.Status.Phase != corev1.PodRunning { - return fmt.Errorf("pod not running yet, phase: %s", pod.Status.Phase) - } - - // Check pod is ready - for _, condition := range pod.Status.Conditions { - if condition.Type == corev1.PodReady { - if condition.Status != corev1.ConditionTrue { - return fmt.Errorf("pod not ready: %s", condition.Message) - } - return nil - } - } - - return fmt.Errorf("pod ready condition not found") - }, timeout, pollingInterval).Should(Succeed(), "Pod should be running and ready") - }) - - It("should have healthy container status", func() { - By("Getting pod name") - pods := &corev1.PodList{} - err := k8sClient.List(ctx, pods, - ctrlclient.InNamespace(testNamespace), - ctrlclient.MatchingLabels{"app.kubernetes.io/instance": vmcpServerName}) - Expect(err).NotTo(HaveOccurred()) - Expect(pods.Items).NotTo(BeEmpty(), "Should have at least one pod") - - podName := pods.Items[0].Name - - By("Checking container status") - Eventually(func() error { - pod := &corev1.Pod{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: podName, - Namespace: testNamespace, - }, pod) - if err != nil { - return err - } - - // Check all containers are ready - for _, status := range pod.Status.ContainerStatuses { - if !status.Ready { - return fmt.Errorf("container %s not ready", status.Name) - } - } - - return nil - }, timeout, pollingInterval).Should(Succeed(), "All containers should be ready") - }) - }) - - Context("Dynamic Backend Discovery Lifecycle", func() { - var ( - dynamicBackend1Name = "dynamic-backend-1" - dynamicBackend2Name = "dynamic-backend-2" - ) - - AfterEach(func() { - // Cleanup any dynamic backends created in tests - _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: dynamicBackend1Name, - Namespace: testNamespace, - }, - }) - _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: dynamicBackend2Name, - Namespace: testNamespace, - }, - }) - }) - - It("should discover new backends added to the group", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Getting initial backend count") - resp, err := http.Get(vmcpURL + "/status") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - var initialStatus map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&initialStatus) - Expect(err).NotTo(HaveOccurred()) - - initialBackends, ok := initialStatus["backends"].([]interface{}) - Expect(ok).To(BeTrue(), "backends field should be an array") - initialCount := len(initialBackends) - - By("Creating a new backend MCPServer in the same group") - newBackend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: dynamicBackend1Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.GofetchServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, newBackend)).To(Succeed()) - - By("Waiting for new backend to be running") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: dynamicBackend1Name, - Namespace: testNamespace, - }, server) - if err != nil { - return err - } - if server.Status.Phase != mcpv1alpha1.MCPServerPhaseRunning { - return fmt.Errorf("backend not running yet, phase: %s", server.Status.Phase) - } - return nil - }, timeout, pollingInterval).Should(Succeed()) - - By("Verifying new backend appears in vMCP status") - Eventually(func() bool { - resp, err := http.Get(vmcpURL + "/status") - if err != nil { - return false - } - defer resp.Body.Close() - - var status map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { - return false - } - - backends, ok := status["backends"].([]interface{}) - if !ok { - return false - } - if len(backends) != initialCount+1 { - return false - } - - // Check that the new backend is in the list - for _, b := range backends { - if strings.Contains(getBackendName(b), dynamicBackend1Name) { - return true - } - } - return false - }, timeout, pollingInterval).Should(BeTrue(), "New backend should be discovered") - }) - - It("should remove backends deleted from the group", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Creating a backend to be deleted") - tempBackend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: dynamicBackend2Name, - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: mcpGroupName, - Image: images.GofetchServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, tempBackend)).To(Succeed()) - - By("Waiting for backend to be running and discovered") - Eventually(func() bool { - resp, err := http.Get(vmcpURL + "/status") - if err != nil { - return false - } - defer resp.Body.Close() - - var status map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { - return false - } - - backends, ok := status["backends"].([]interface{}) - if !ok { - return false - } - for _, b := range backends { - if strings.Contains(getBackendName(b), dynamicBackend2Name) { - return true - } - } - return false - }, timeout, pollingInterval).Should(BeTrue(), "Backend should be discovered") - - By("Deleting the backend") - Expect(k8sClient.Delete(ctx, tempBackend)).To(Succeed()) - - By("Waiting for backend to be fully deleted from K8s") - Eventually(func() bool { - err := k8sClient.Get(ctx, ctrlclient.ObjectKey{ - Name: dynamicBackend2Name, - Namespace: testNamespace, - }, &mcpv1alpha1.MCPServer{}) - return errors.IsNotFound(err) - }, timeout, pollingInterval).Should(BeTrue(), "Backend should be deleted from K8s") - - By("Waiting for backend pod to be deleted") - Eventually(func() int { - podList := &corev1.PodList{} - err := k8sClient.List(ctx, podList, ctrlclient.InNamespace(testNamespace), - ctrlclient.MatchingLabels{"app.kubernetes.io/name": dynamicBackend2Name}) - if err != nil { - return -1 - } - return len(podList.Items) - }, timeout, pollingInterval).Should(Equal(0), "Backend pods should be deleted") - - By("Verifying backend is removed from vMCP status") - Eventually(func() bool { - resp, err := http.Get(vmcpURL + "/status") - if err != nil { - return false - } - defer resp.Body.Close() - - var status map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { - return false - } - - backends, ok := status["backends"].([]interface{}) - if !ok { - return false - } - - for _, b := range backends { - if strings.Contains(getBackendName(b), dynamicBackend2Name) { - return false // Backend still present - } - } - return true // Backend not found (removed) - }, timeout*2, pollingInterval).Should(BeTrue(), "Deleted backend should be removed from status") - }) - - It("should not discover backends from different groups", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - differentGroup := "different-group" - - By("Creating a group with a different name") - CreateMCPGroupAndWait(ctx, k8sClient, differentGroup, testNamespace, - "Different group for isolation testing", timeout, pollingInterval) - defer func() { - _ = k8sClient.Delete(ctx, &mcpv1alpha1.MCPGroup{ - ObjectMeta: metav1.ObjectMeta{ - Name: differentGroup, - Namespace: testNamespace, - }, - }) - }() - - By("Creating a backend in the different group") - otherGroupBackend := &mcpv1alpha1.MCPServer{ - ObjectMeta: metav1.ObjectMeta{ - Name: "other-group-backend", - Namespace: testNamespace, - }, - Spec: mcpv1alpha1.MCPServerSpec{ - GroupRef: differentGroup, // Different group - Image: images.GofetchServerImage, - Transport: "streamable-http", - ProxyPort: 8080, - McpPort: 8080, - }, - } - Expect(k8sClient.Create(ctx, otherGroupBackend)).To(Succeed()) - defer func() { - _ = k8sClient.Delete(ctx, otherGroupBackend) - }() - - By("Waiting for backend to be running") - Eventually(func() error { - server := &mcpv1alpha1.MCPServer{} - err := k8sClient.Get(ctx, types.NamespacedName{ - Name: "other-group-backend", - Namespace: testNamespace, - }, server) - if err != nil { - return err - } - if server.Status.Phase != mcpv1alpha1.MCPServerPhaseRunning { - return fmt.Errorf("backend not running yet") - } - return nil - }, timeout, pollingInterval).Should(Succeed()) - - By("Verifying backend from different group is NOT discovered") - Consistently(func() bool { - resp, err := http.Get(vmcpURL + "/status") - if err != nil { - return true // Continue checking - } - defer resp.Body.Close() - - var status map[string]interface{} - if err := json.NewDecoder(resp.Body).Decode(&status); err != nil { - return true - } - - backends, ok := status["backends"].([]interface{}) - if !ok { - return true // Continue checking if structure is unexpected - } - for _, b := range backends { - if strings.Contains(getBackendName(b), "other-group-backend") { - return false // Backend found - test should fail - } - } - return true // Backend not found - correct - }, 10*time.Second, pollingInterval).Should(BeTrue(), "Backend from different group should not be discovered") - }) - }) - - Context("Health Endpoints", func() { - It("should expose /health endpoint that always returns 200", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Checking /health endpoint") - resp, err := http.Get(vmcpURL + "/health") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - Expect(resp.StatusCode).To(Equal(http.StatusOK)) - - var health map[string]string - err = json.NewDecoder(resp.Body).Decode(&health) - Expect(err).NotTo(HaveOccurred()) - Expect(health["status"]).To(Equal("ok")) - }) - - It("should distinguish between /health and /readyz", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Getting /health response") - healthResp, err := http.Get(vmcpURL + "/health") - Expect(err).NotTo(HaveOccurred()) - defer healthResp.Body.Close() - - By("Getting /readyz response") - readyResp, err := http.Get(vmcpURL + "/readyz") - Expect(err).NotTo(HaveOccurred()) - defer readyResp.Body.Close() - - // Both should return 200 when ready - Expect(healthResp.StatusCode).To(Equal(http.StatusOK)) - Expect(readyResp.StatusCode).To(Equal(http.StatusOK)) - - // Parse both responses - var health map[string]string - err = json.NewDecoder(healthResp.Body).Decode(&health) - Expect(err).NotTo(HaveOccurred()) - - var readiness ReadinessResponse - err = json.NewDecoder(readyResp.Body).Decode(&readiness) - Expect(err).NotTo(HaveOccurred()) - - // Health is simple status - Expect(health).To(HaveKey("status")) - Expect(health).NotTo(HaveKey("mode")) - - // Readiness includes mode information - Expect(readiness.Status).To(Equal("ready")) - Expect(readiness.Mode).To(Equal("dynamic")) - }) - }) - - Context("Status Endpoint", func() { - It("should expose /status endpoint with group reference", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Checking /status endpoint") - resp, err := http.Get(vmcpURL + "/status") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - Expect(resp.StatusCode).To(Equal(http.StatusOK)) - - var status map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&status) - Expect(err).NotTo(HaveOccurred()) - - By("Verifying group_ref is present") - Expect(status).To(HaveKey("group_ref")) - groupRef, ok := status["group_ref"].(string) - Expect(ok).To(BeTrue()) - Expect(groupRef).To(ContainSubstring(mcpGroupName)) - }) - - It("should list discovered backends", func() { - vmcpURL := fmt.Sprintf("http://localhost:%d", vmcpNodePort) - - By("Getting /status response") - resp, err := http.Get(vmcpURL + "/status") - Expect(err).NotTo(HaveOccurred()) - defer resp.Body.Close() - - var status map[string]interface{} - err = json.NewDecoder(resp.Body).Decode(&status) - Expect(err).NotTo(HaveOccurred()) - - By("Verifying backends are listed") - Expect(status).To(HaveKey("backends")) - backends, ok := status["backends"].([]interface{}) - Expect(ok).To(BeTrue()) - Expect(backends).NotTo(BeEmpty(), "Should have at least one backend") - - // Verify backend structure - backend, ok := backends[0].(map[string]interface{}) - Expect(ok).To(BeTrue(), "backend should be a map") - Expect(backend).To(HaveKey("name")) - Expect(backend).To(HaveKey("health")) - Expect(backend).To(HaveKey("transport")) - }) - }) -}) diff --git a/test/integration/vmcp/helpers/helpers_test.go b/test/integration/vmcp/helpers/helpers_test.go new file mode 100644 index 0000000000..0438b9deb2 --- /dev/null +++ b/test/integration/vmcp/helpers/helpers_test.go @@ -0,0 +1,76 @@ +package helpers + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" +) + +// TestGetToolNames tests the GetToolNames helper function. +func TestGetToolNames(t *testing.T) { + t.Parallel() + tests := []struct { + name string + result *mcp.ListToolsResult + expected []string + }{ + { + name: "empty tools", + result: &mcp.ListToolsResult{ + Tools: []mcp.Tool{}, + }, + expected: []string{}, + }, + { + name: "single tool", + result: &mcp.ListToolsResult{ + Tools: []mcp.Tool{ + {Name: "tool1"}, + }, + }, + expected: []string{"tool1"}, + }, + { + name: "multiple tools", + result: &mcp.ListToolsResult{ + Tools: []mcp.Tool{ + {Name: "tool1"}, + {Name: "tool2"}, + {Name: "tool3"}, + }, + }, + expected: []string{"tool1", "tool2", "tool3"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + names := GetToolNames(tt.result) + assert.Equal(t, tt.expected, names) + }) + } +} + +// TestAssertTextContains tests the AssertTextContains helper. +func TestAssertTextContains(t *testing.T) { + t.Parallel() + t.Run("all substrings present", func(t *testing.T) { + t.Parallel() + text := "hello world, this is a test" + // Should not fail + AssertTextContains(t, text, "hello", "world", "test") + }) +} + +// TestAssertTextNotContains tests the AssertTextNotContains helper. +func TestAssertTextNotContains(t *testing.T) { + t.Parallel() + t.Run("no forbidden substrings", func(t *testing.T) { + t.Parallel() + text := "hello world" + // Should not fail + AssertTextNotContains(t, text, "password", "secret") + }) +} From f68cb038ed04d00b6c4687cb9641567fc59b1952 Mon Sep 17 00:00:00 2001 From: Don Browne Date: Tue, 20 Jan 2026 17:08:12 +0000 Subject: [PATCH 11/16] Run the API E2E test server as a standalone process (#3356) Run the server as a standalone process Previously it was instantiating the server in-process. Use the full binary to ensure that the tests are realistic. --- test/e2e/api_helpers.go | 144 ++++++++++++++++++++++------------------ 1 file changed, 79 insertions(+), 65 deletions(-) diff --git a/test/e2e/api_helpers.go b/test/e2e/api_helpers.go index 5ce35c68f1..b5a65edb8b 100644 --- a/test/e2e/api_helpers.go +++ b/test/e2e/api_helpers.go @@ -4,15 +4,16 @@ package e2e import ( "context" "fmt" - "net" "net/http" + "os/exec" + "strconv" + "strings" "time" . "github.com/onsi/ginkgo/v2" //nolint:staticcheck // Standard practice for Ginkgo . "github.com/onsi/gomega" //nolint:staticcheck // Standard practice for Gomega - "github.com/stacklok/toolhive/pkg/api" - "github.com/stacklok/toolhive/pkg/container" + "github.com/stacklok/toolhive/pkg/networking" ) // ServerConfig holds configuration for the API server in tests @@ -33,83 +34,89 @@ func NewServerConfig() *ServerConfig { } } -// Server represents a running API server instance for testing +// Server represents a running API server instance for testing. +// It runs `thv serve` as a subprocess. type Server struct { config *ServerConfig baseURL string + cmd *exec.Cmd ctx context.Context cancel context.CancelFunc - serverErr chan error - done chan struct{} httpClient *http.Client + port int + stderr *strings.Builder + stdout *strings.Builder } -// NewServer creates and starts a new API server instance +// NewServer creates and starts a new API server instance by running `thv serve` as a subprocess. func NewServer(config *ServerConfig) (*Server, error) { - ctx, cancel := context.WithCancel(context.Background()) + testConfig := NewTestConfig() - // Create a temporary listener to get a free port - listener, err := net.Listen("tcp", config.Address) + // Find a free port + port, err := networking.FindOrUsePort(0) if err != nil { + return nil, fmt.Errorf("failed to find free port: %w", err) + } + + // Create temporary config directory (similar to CLI tests) + tempXdgConfigHome := GinkgoT().TempDir() + tempHome := GinkgoT().TempDir() + + ctx, cancel := context.WithCancel(context.Background()) + + // Create string builders to capture output + var stdout, stderr strings.Builder + + // Create the command: thv serve --host 127.0.0.1 --port + //nolint:gosec // Intentional for e2e testing + cmd := exec.CommandContext( + ctx, + testConfig.THVBinary, + "serve", + "--host", + "127.0.0.1", + "--port", + strconv.Itoa(port), + ) + // Set environment variables including temporary config paths + cmd.Env = append([]string{ + "TOOLHIVE_DEV=true", + fmt.Sprintf("XDG_CONFIG_HOME=%s", tempXdgConfigHome), + fmt.Sprintf("HOME=%s", tempHome), + }, cmd.Env...) + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + // Start the server process + if err := cmd.Start(); err != nil { cancel() - return nil, fmt.Errorf("failed to create listener: %w", err) + return nil, fmt.Errorf("failed to start thv serve: %w", err) } - actualAddr := listener.Addr().String() - // Close the listener immediately as the server will create its own - _ = listener.Close() server := &Server{ - config: config, - baseURL: fmt.Sprintf("http://%s", actualAddr), - ctx: ctx, - cancel: cancel, - serverErr: make(chan error, 1), - done: make(chan struct{}), + config: config, + baseURL: fmt.Sprintf("http://127.0.0.1:%d", port), + cmd: cmd, + ctx: ctx, + cancel: cancel, httpClient: &http.Client{ Timeout: config.RequestTimeout, }, + port: port, + stdout: &stdout, + stderr: &stderr, } - // Start the server in a goroutine - go func() { - defer close(server.done) - // Create container runtime for the API server - containerRuntime, err := container.NewFactory().Create(ctx) - if err != nil { - server.serverErr <- fmt.Errorf("failed to create container runtime: %w", err) - return - } - - builder := api.NewServerBuilder(). - WithAddress(actualAddr). - WithUnixSocket(false). - WithDebugMode(config.DebugMode). - WithDocs(false). - WithOIDCConfig(nil). - WithContainerRuntime(containerRuntime) - - apiServer, err := api.NewServer(ctx, builder) - if err != nil { - server.serverErr <- fmt.Errorf("failed to create API server: %w", err) - return - } - - if err := apiServer.Start(ctx); err != nil { - server.serverErr <- err - return - } - }() - // Wait for server to be ready if err := server.WaitForReady(); err != nil { - server.Stop() + _ = server.Stop() return nil, err } return server, nil } -// WaitForReady waits for the API server to be ready to accept requests +// WaitForReady waits for the API server to be ready to accept requests. func (s *Server) WaitForReady() error { ctx, cancel := context.WithTimeout(context.Background(), s.config.StartTimeout) defer cancel() @@ -120,9 +127,9 @@ func (s *Server) WaitForReady() error { for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for API server to be ready") - case err := <-s.serverErr: - return fmt.Errorf("API server failed to start: %w", err) + // Include server logs in the error message for debugging + return fmt.Errorf("timeout waiting for API server to be ready on port %d.\nStdout: %s\nStderr: %s", + s.port, s.stdout.String(), s.stderr.String()) case <-ticker.C: // Try to connect to the health endpoint req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.baseURL+"/health", nil) @@ -136,7 +143,7 @@ func (s *Server) WaitForReady() error { } _ = resp.Body.Close() - // Server is ready if we get the expected response. + // Server is ready if we get the expected response if resp.StatusCode == http.StatusNoContent { return nil } @@ -144,14 +151,21 @@ func (s *Server) WaitForReady() error { } } -// Stop stops the API server -func (s *Server) Stop() { - s.cancel() - // Wait for server to shut down gracefully - <-s.done +// Stop stops the API server subprocess. +func (s *Server) Stop() error { + if s.cancel != nil { + s.cancel() + } + + if s.cmd != nil && s.cmd.Process != nil { + // Wait for the process to exit + _ = s.cmd.Wait() + } + + return nil } -// Get performs a GET request to the specified path +// Get performs a GET request to the specified path. func (s *Server) Get(path string) (*http.Response, error) { req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.baseURL+path, nil) if err != nil { @@ -160,7 +174,7 @@ func (s *Server) Get(path string) (*http.Response, error) { return s.httpClient.Do(req) } -// GetWithHeaders performs a GET request with custom headers +// GetWithHeaders performs a GET request with custom headers. func (s *Server) GetWithHeaders(path string, headers map[string]string) (*http.Response, error) { req, err := http.NewRequestWithContext(s.ctx, http.MethodGet, s.baseURL+path, nil) if err != nil { @@ -174,7 +188,7 @@ func (s *Server) GetWithHeaders(path string, headers map[string]string) (*http.R return s.httpClient.Do(req) } -// BaseURL returns the base URL of the API server +// BaseURL returns the base URL of the API server. func (s *Server) BaseURL() string { return s.baseURL } @@ -187,7 +201,7 @@ func StartServer(config *ServerConfig) *Server { // Register cleanup DeferCleanup(func() { - server.Stop() + _ = server.Stop() }) return server From a48547e7ff65d2884872990f15491b2c83eb657f Mon Sep 17 00:00:00 2001 From: Nigel Brown Date: Mon, 19 Jan 2026 12:49:22 +0000 Subject: [PATCH 12/16] fix: Resolve tool names in optim.find_tool to match routing table (#3337) * fix: Resolve tool names in optim.find_tool to match routing table --- pkg/vmcp/discovery/middleware_test.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pkg/vmcp/discovery/middleware_test.go b/pkg/vmcp/discovery/middleware_test.go index 8594b89c29..0bf4b03f82 100644 --- a/pkg/vmcp/discovery/middleware_test.go +++ b/pkg/vmcp/discovery/middleware_test.go @@ -347,7 +347,17 @@ func TestMiddleware_CapabilitiesInContext(t *testing.T) { // Use Do to capture and verify backends separately, since order may vary mockMgr.EXPECT(). - Discover(gomock.Any(), unorderedBackendsMatcher{backends}). + Discover(gomock.Any(), gomock.Any()). + Do(func(_ context.Context, actualBackends []vmcp.Backend) { + // Verify that we got the expected backends regardless of order + assert.Len(t, actualBackends, 2) + backendIDs := make(map[string]bool) + for _, b := range actualBackends { + backendIDs[b.ID] = true + } + assert.True(t, backendIDs["backend1"], "backend1 should be present") + assert.True(t, backendIDs["backend2"], "backend2 should be present") + }). Return(expectedCaps, nil) // Create handler that inspects context in detail From 094fba50d1f0d55b29c1d95cfdfd849e91235f63 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 18:56:18 +0000 Subject: [PATCH 13/16] feat: Add DeepCopy and Kubernetes service resolution for optimizer config - Use DeepCopy() for automatic passthrough of config fields (Optimizer, Metadata, etc.) - Add resolveEmbeddingService() to resolve Kubernetes Service names to URLs - Ensures optimizer config is properly converted from CRD to runtime config - Resolves embeddingService references in Kubernetes deployments --- cmd/thv-operator/pkg/vmcpconfig/converter.go | 65 ++++++++++++++++++-- 1 file changed, 60 insertions(+), 5 deletions(-) diff --git a/cmd/thv-operator/pkg/vmcpconfig/converter.go b/cmd/thv-operator/pkg/vmcpconfig/converter.go index 000d306970..4b357e0e49 100644 --- a/cmd/thv-operator/pkg/vmcpconfig/converter.go +++ b/cmd/thv-operator/pkg/vmcpconfig/converter.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/go-logr/logr" + corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" @@ -56,15 +57,23 @@ func NewConverter(oidcResolver oidc.Resolver, k8sClient client.Client) (*Convert }, nil } -// Convert converts VirtualMCPServer CRD spec to vmcp Config +// Convert converts VirtualMCPServer CRD spec to vmcp Config. +// +// The conversion starts with a DeepCopy of the embedded config.Config from the CRD spec. +// This ensures that simple fields (like Optimizer, Metadata, etc.) are automatically +// passed through without explicit mapping. Only fields that require special handling +// (auth, aggregation, composite tools, telemetry) are explicitly converted below. func (c *Converter) Convert( ctx context.Context, vmcp *mcpv1alpha1.VirtualMCPServer, ) (*vmcpconfig.Config, error) { - config := &vmcpconfig.Config{ - Name: vmcp.Name, - Group: vmcp.Spec.Config.Group, - } + // Start with a deep copy of the embedded config for automatic field passthrough. + // This ensures new fields added to config.Config are automatically included + // without requiring explicit mapping in this converter. + config := vmcp.Spec.Config.DeepCopy() + + // Override name with the CR name (authoritative source) + config.Name = vmcp.Name // Convert IncomingAuth - required field, no defaults if vmcp.Spec.IncomingAuth != nil { @@ -132,6 +141,24 @@ func (c *Converter) Convert( config.Audit.Component = vmcp.Name } + // Convert optimizer config - resolve embeddingService to embeddingURL if needed + if vmcp.Spec.Config.Optimizer != nil { + optimizerConfig := vmcp.Spec.Config.Optimizer.DeepCopy() + + // If embeddingService is set, resolve it to embeddingURL + if optimizerConfig.EmbeddingService != "" && optimizerConfig.EmbeddingURL == "" { + embeddingURL, err := c.resolveEmbeddingService(ctx, vmcp.Namespace, optimizerConfig.EmbeddingService) + if err != nil { + return nil, fmt.Errorf("failed to resolve embedding service %s: %w", optimizerConfig.EmbeddingService, err) + } + optimizerConfig.EmbeddingURL = embeddingURL + // Clear embeddingService since we've resolved it to URL + optimizerConfig.EmbeddingService = "" + } + + config.Optimizer = optimizerConfig + } + // Apply operational defaults (fills missing values) config.EnsureOperationalDefaults() @@ -597,3 +624,31 @@ func validateCompositeToolNames(tools []vmcpconfig.CompositeToolConfig) error { } return nil } + +// resolveEmbeddingService resolves a Kubernetes service name to its URL by querying the service. +// Returns the service URL in format: http://..svc.cluster.local: +func (c *Converter) resolveEmbeddingService(ctx context.Context, namespace, serviceName string) (string, error) { + // Get the service + svc := &corev1.Service{} + key := types.NamespacedName{ + Name: serviceName, + Namespace: namespace, + } + if err := c.k8sClient.Get(ctx, key, svc); err != nil { + return "", fmt.Errorf("failed to get service %s/%s: %w", namespace, serviceName, err) + } + + // Find the first port (typically there's only one for embedding services) + if len(svc.Spec.Ports) == 0 { + return "", fmt.Errorf("service %s/%s has no ports", namespace, serviceName) + } + + port := svc.Spec.Ports[0].Port + if port == 0 { + return "", fmt.Errorf("service %s/%s has invalid port", namespace, serviceName) + } + + // Construct URL using full DNS name + url := fmt.Sprintf("http://%s.%s.svc.cluster.local:%d", serviceName, namespace, port) + return url, nil +} From 4ac86db7448dcb55cd77e674196b4bc9eea04ccc Mon Sep 17 00:00:00 2001 From: nigel brown Date: Tue, 20 Jan 2026 18:57:33 +0000 Subject: [PATCH 14/16] fix: Add remaining Kubernetes optimizer integration fixes from PR #3359 - Add CLI fallback for embeddingService when not resolved by operator - Normalize localhost to 127.0.0.1 in embeddings to avoid IPv6 issues - Add HTTP timeout (30s) to prevent hanging connections - Remove WithContinuousListening() to use timeout-based approach --- cmd/vmcp/app/commands.go | 14 ++++++++++++++ pkg/optimizer/embeddings/manager.go | 8 ++++++-- pkg/optimizer/embeddings/ollama.go | 13 ++++++++++++- pkg/vmcp/client/client.go | 6 ++++-- 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/cmd/vmcp/app/commands.go b/cmd/vmcp/app/commands.go index 5bbb21e635..aa7bc40948 100644 --- a/cmd/vmcp/app/commands.go +++ b/cmd/vmcp/app/commands.go @@ -434,6 +434,20 @@ func runServe(cmd *cobra.Command, _ []string) error { if cfg.Optimizer.HybridSearchRatio != nil { hybridRatio = *cfg.Optimizer.HybridSearchRatio } + + // embeddingURL should already be resolved from embeddingService by the operator + // If embeddingService is still set (CLI mode), log a warning + if cfg.Optimizer.EmbeddingService != "" { + logger.Warnf("embeddingService is set but not resolved to embeddingURL. This should be handled by the operator. Falling back to default port 11434") + // Simple fallback for CLI/testing scenarios + namespace := os.Getenv("POD_NAMESPACE") + if namespace != "" { + cfg.Optimizer.EmbeddingURL = fmt.Sprintf("http://%s.%s.svc.cluster.local:11434", cfg.Optimizer.EmbeddingService, namespace) + } else { + cfg.Optimizer.EmbeddingURL = fmt.Sprintf("http://%s:11434", cfg.Optimizer.EmbeddingService) + } + } + serverCfg.OptimizerConfig = &vmcpserver.OptimizerConfig{ Enabled: cfg.Optimizer.Enabled, PersistPath: cfg.Optimizer.PersistPath, diff --git a/pkg/optimizer/embeddings/manager.go b/pkg/optimizer/embeddings/manager.go index 70ac838492..5264112c53 100644 --- a/pkg/optimizer/embeddings/manager.go +++ b/pkg/optimizer/embeddings/manager.go @@ -2,6 +2,7 @@ package embeddings import ( "fmt" + "strings" "sync" "github.com/stacklok/toolhive/pkg/logger" @@ -24,7 +25,7 @@ type Config struct { BackendType string // BaseURL is the base URL for the embedding service - // - Ollama: http://localhost:11434 + // - Ollama: http://127.0.0.1:11434 (or http://localhost:11434, will be normalized to 127.0.0.1) // - vLLM: http://localhost:8000 BaseURL string @@ -84,7 +85,10 @@ func NewManager(config *Config) (*Manager, error) { // Use Ollama native API (requires ollama serve) baseURL := config.BaseURL if baseURL == "" { - baseURL = "http://localhost:11434" + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = strings.ReplaceAll(baseURL, "localhost", "127.0.0.1") } model := config.Model if model == "" { diff --git a/pkg/optimizer/embeddings/ollama.go b/pkg/optimizer/embeddings/ollama.go index a05af2af11..9d6887375a 100644 --- a/pkg/optimizer/embeddings/ollama.go +++ b/pkg/optimizer/embeddings/ollama.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "strings" "github.com/stacklok/toolhive/pkg/logger" ) @@ -29,12 +30,22 @@ type ollamaEmbedResponse struct { Embedding []float64 `json:"embedding"` } +// normalizeLocalhostURL converts localhost to 127.0.0.1 to avoid IPv6 resolution issues +func normalizeLocalhostURL(url string) string { + // Replace localhost with 127.0.0.1 to ensure IPv4 connection + // This prevents connection refused errors when Ollama only listens on IPv4 + return strings.ReplaceAll(url, "localhost", "127.0.0.1") +} + // NewOllamaBackend creates a new Ollama backend // Requires Ollama to be running locally: ollama serve // Default model: all-minilm (all-MiniLM-L6-v2, 384 dimensions) func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) { if baseURL == "" { - baseURL = "http://localhost:11434" + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = normalizeLocalhostURL(baseURL) } if model == "" { model = "all-minilm" // Default embedding model (all-MiniLM-L6-v2) diff --git a/pkg/vmcp/client/client.go b/pkg/vmcp/client/client.go index e99533a83a..dcff9e7ee2 100644 --- a/pkg/vmcp/client/client.go +++ b/pkg/vmcp/client/client.go @@ -12,6 +12,7 @@ import ( "io" "net" "net/http" + "time" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" @@ -198,8 +199,10 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm }) // Create HTTP client with configured transport chain + // Set timeouts to prevent long-lived connections that require continuous listening httpClient := &http.Client{ Transport: sizeLimitedTransport, + Timeout: 30 * time.Second, // Prevent hanging on connections } var c *client.Client @@ -208,8 +211,7 @@ func (h *httpBackendClient) defaultClientFactory(ctx context.Context, target *vm case "streamable-http", "streamable": c, err = client.NewStreamableHttpClient( target.BaseURL, - transport.WithHTTPTimeout(0), - transport.WithContinuousListening(), + transport.WithHTTPTimeout(30*time.Second), // Set timeout instead of 0 transport.WithHTTPBasicClient(httpClient), ) if err != nil { From ae5d8f3fe8cf1e21ccc8f46269ea4a48f92400a0 Mon Sep 17 00:00:00 2001 From: nigel brown Date: Wed, 21 Jan 2026 12:31:10 +0000 Subject: [PATCH 15/16] Add OpenTelemetry tracing to capability aggregation Add tracing spans to all aggregator methods to enable visibility of capability aggregation in Jaeger. This includes spans for: - AggregateCapabilities (parent span) - QueryAllCapabilities (parallel backend queries) - QueryCapabilities (per-backend queries) - ResolveConflicts (conflict resolution) - MergeCapabilities (final merge) All spans include relevant attributes like backend counts, tool/resource/prompt counts, and error recording. This fixes the issue where capability aggregation logs appeared but no spans were visible in Jaeger. --- pkg/vmcp/aggregator/default_aggregator.go | 93 ++++++++++++++++++++++- 1 file changed, 91 insertions(+), 2 deletions(-) diff --git a/pkg/vmcp/aggregator/default_aggregator.go b/pkg/vmcp/aggregator/default_aggregator.go index 19abf71b5d..dc2a49a4b3 100644 --- a/pkg/vmcp/aggregator/default_aggregator.go +++ b/pkg/vmcp/aggregator/default_aggregator.go @@ -5,6 +5,10 @@ import ( "fmt" "sync" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "github.com/stacklok/toolhive/pkg/logger" @@ -18,6 +22,7 @@ type defaultAggregator struct { backendClient vmcp.BackendClient conflictResolver ConflictResolver toolConfigMap map[string]*config.WorkloadToolConfig // Maps backend ID to tool config + tracer trace.Tracer } // NewDefaultAggregator creates a new default aggregator implementation. @@ -40,12 +45,20 @@ func NewDefaultAggregator( backendClient: backendClient, conflictResolver: conflictResolver, toolConfigMap: toolConfigMap, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/aggregator"), } } // QueryCapabilities queries a single backend for its MCP capabilities. // Returns the raw capabilities (tools, resources, prompts) from the backend. func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp.Backend) (*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryCapabilities", + trace.WithAttributes( + attribute.String("backend.id", backend.ID), + ), + ) + defer span.End() + logger.Debugf("Querying capabilities from backend %s", backend.ID) // Create a BackendTarget from the Backend @@ -55,6 +68,8 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. // Query capabilities using the backend client capabilities, err := a.backendClient.ListCapabilities(ctx, target) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("%w: %s: %w", ErrBackendQueryFailed, backend.ID, err) } @@ -71,6 +86,12 @@ func (a *defaultAggregator) QueryCapabilities(ctx context.Context, backend vmcp. SupportsSampling: capabilities.SupportsSampling, } + span.SetAttributes( + attribute.Int("tools.count", len(result.Tools)), + attribute.Int("resources.count", len(result.Resources)), + attribute.Int("prompts.count", len(result.Prompts)), + ) + logger.Debugf("Backend %s: %d tools (after filtering/overrides), %d resources, %d prompts", backend.ID, len(result.Tools), len(result.Resources), len(result.Prompts)) @@ -83,6 +104,13 @@ func (a *defaultAggregator) QueryAllCapabilities( ctx context.Context, backends []vmcp.Backend, ) (map[string]*BackendCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.QueryAllCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Querying capabilities from %d backends", len(backends)) // Use errgroup for parallel queries with context cancellation @@ -115,13 +143,22 @@ func (a *defaultAggregator) QueryAllCapabilities( // Wait for all queries to complete if err := g.Wait(); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("capability queries failed: %w", err) } if len(capabilities) == 0 { - return nil, fmt.Errorf("no backends returned capabilities") + err := fmt.Errorf("no backends returned capabilities") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err } + span.SetAttributes( + attribute.Int("successful.backends", len(capabilities)), + ) + logger.Infof("Successfully queried %d/%d backends", len(capabilities), len(backends)) return capabilities, nil } @@ -132,6 +169,13 @@ func (a *defaultAggregator) ResolveConflicts( ctx context.Context, capabilities map[string]*BackendCapabilities, ) (*ResolvedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.ResolveConflicts", + trace.WithAttributes( + attribute.Int("backends.count", len(capabilities)), + ), + ) + defer span.End() + logger.Debugf("Resolving conflicts across %d backends", len(capabilities)) // Group tools by backend for conflict resolution @@ -147,6 +191,8 @@ func (a *defaultAggregator) ResolveConflicts( if a.conflictResolver != nil { resolvedTools, err = a.conflictResolver.ResolveToolConflicts(ctx, toolsByBackend) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("conflict resolution failed: %w", err) } } else { @@ -188,6 +234,12 @@ func (a *defaultAggregator) ResolveConflicts( resolved.SupportsSampling = resolved.SupportsSampling || caps.SupportsSampling } + span.SetAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ) + logger.Debugf("Resolved %d unique tools, %d resources, %d prompts", len(resolved.Tools), len(resolved.Resources), len(resolved.Prompts)) @@ -196,11 +248,20 @@ func (a *defaultAggregator) ResolveConflicts( // MergeCapabilities creates the final unified capability view and routing table. // Uses the backend registry to populate full BackendTarget information for routing. -func (*defaultAggregator) MergeCapabilities( +func (a *defaultAggregator) MergeCapabilities( ctx context.Context, resolved *ResolvedCapabilities, registry vmcp.BackendRegistry, ) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.MergeCapabilities", + trace.WithAttributes( + attribute.Int("resolved.tools", len(resolved.Tools)), + attribute.Int("resolved.resources", len(resolved.Resources)), + attribute.Int("resolved.prompts", len(resolved.Prompts)), + ), + ) + defer span.End() + logger.Debugf("Merging capabilities into final view") // Create routing table @@ -301,6 +362,13 @@ func (*defaultAggregator) MergeCapabilities( }, } + span.SetAttributes( + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Merged capabilities: %d tools, %d resources, %d prompts", aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) @@ -313,6 +381,13 @@ func (*defaultAggregator) MergeCapabilities( // 3. Resolve conflicts // 4. Merge into final view with full backend information func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends []vmcp.Backend) (*AggregatedCapabilities, error) { + ctx, span := a.tracer.Start(ctx, "aggregator.AggregateCapabilities", + trace.WithAttributes( + attribute.Int("backends.count", len(backends)), + ), + ) + defer span.End() + logger.Infof("Starting capability aggregation for %d backends", len(backends)) // Step 1: Create registry from discovered backends @@ -322,24 +397,38 @@ func (a *defaultAggregator) AggregateCapabilities(ctx context.Context, backends // Step 2: Query all backends capabilities, err := a.QueryAllCapabilities(ctx, backends) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to query backends: %w", err) } // Step 3: Resolve conflicts resolved, err := a.ResolveConflicts(ctx, capabilities) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to resolve conflicts: %w", err) } // Step 4: Merge into final view with full backend information aggregated, err := a.MergeCapabilities(ctx, resolved, registry) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) return nil, fmt.Errorf("failed to merge capabilities: %w", err) } // Update metadata with backend count aggregated.Metadata.BackendCount = len(backends) + span.SetAttributes( + attribute.Int("aggregated.backends", aggregated.Metadata.BackendCount), + attribute.Int("aggregated.tools", aggregated.Metadata.ToolCount), + attribute.Int("aggregated.resources", aggregated.Metadata.ResourceCount), + attribute.Int("aggregated.prompts", aggregated.Metadata.PromptCount), + attribute.String("conflict.strategy", string(aggregated.Metadata.ConflictStrategy)), + ) + logger.Infof("Capability aggregation complete: %d backends, %d tools, %d resources, %d prompts", aggregated.Metadata.BackendCount, aggregated.Metadata.ToolCount, aggregated.Metadata.ResourceCount, aggregated.Metadata.PromptCount) From 22e020c2746e542df5ae36c10bd11e1d61114a6c Mon Sep 17 00:00:00 2001 From: nigel brown Date: Wed, 21 Jan 2026 15:04:11 +0000 Subject: [PATCH 16/16] Fix unrecognized dotty names Signed-off-by: nigel brown --- pkg/vmcp/config/config.go | 4 +- .../find_tool_semantic_search_test.go | 8 +- .../find_tool_string_matching_test.go | 6 +- pkg/vmcp/optimizer/optimizer.go | 160 ++++++++++++++++-- pkg/vmcp/optimizer/optimizer_handlers_test.go | 40 ++--- .../optimizer/optimizer_integration_test.go | 6 +- pkg/vmcp/router/default_router.go | 6 +- pkg/vmcp/server/mocks/mock_watcher.go | 28 +++ pkg/vmcp/server/server.go | 55 ++++-- 9 files changed, 245 insertions(+), 68 deletions(-) diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index 2f05902b4d..239e4a6c34 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -148,7 +148,7 @@ type Config struct { Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` // Optimizer configures the MCP optimizer for context optimization on large toolsets. - // When enabled, vMCP exposes optim.find_tool and optim.call_tool operations to clients + // When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients // instead of all backend tools directly. This reduces token usage by allowing // LLMs to discover relevant tools on demand rather than receiving all tool definitions. // +optional @@ -700,7 +700,7 @@ type OutputProperty struct { // +gendoc type OptimizerConfig struct { // Enabled determines whether the optimizer is active. - // When true, vMCP exposes optim.find_tool and optim.call_tool instead of all backend tools. + // When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. // +optional Enabled bool `json:"enabled" yaml:"enabled"` diff --git a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go index a539937fe9..817c11eb8b 100644 --- a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go +++ b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go @@ -272,7 +272,7 @@ func TestFindTool_SemanticSearch(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": tc.query, "tool_keywords": tc.keywords, @@ -472,7 +472,7 @@ func TestFindTool_SemanticVsKeyword(t *testing.T) { // Test semantic search requestSemantic := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": query, "tool_keywords": "", @@ -489,7 +489,7 @@ func TestFindTool_SemanticVsKeyword(t *testing.T) { // Test keyword search requestKeyword := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": query, "tool_keywords": "", @@ -647,7 +647,7 @@ func TestFindTool_SemanticSimilarityScores(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": query, "tool_keywords": "", diff --git a/pkg/vmcp/optimizer/find_tool_string_matching_test.go b/pkg/vmcp/optimizer/find_tool_string_matching_test.go index b994d7b95d..d144a69b51 100644 --- a/pkg/vmcp/optimizer/find_tool_string_matching_test.go +++ b/pkg/vmcp/optimizer/find_tool_string_matching_test.go @@ -286,7 +286,7 @@ func TestFindTool_StringMatching(t *testing.T) { // Create the tool call request request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": tc.query, "tool_keywords": tc.keywords, @@ -506,7 +506,7 @@ func TestFindTool_ExactStringMatch(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": tc.query, "tool_keywords": tc.keywords, @@ -651,7 +651,7 @@ func TestFindTool_CaseInsensitive(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": query, "tool_keywords": strings.ToLower(query), diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index 03e32ce5d3..225f8374dd 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -1,8 +1,8 @@ // Package optimizer provides vMCP integration for semantic tool discovery. // // This package implements the RFC-0022 optimizer integration, exposing: -// - optim.find_tool: Semantic/keyword-based tool discovery -// - optim.call_tool: Dynamic tool invocation across backends +// - optim_find_tool: Semantic/keyword-based tool discovery +// - optim_call_tool: Dynamic tool invocation across backends // // Architecture: // - Embeddings are generated during session initialization (OnRegisterSession hook) @@ -110,7 +110,7 @@ func NewIntegration( // This hook: // 1. Extracts backend tools from discovered capabilities // 2. Generates embeddings for all tools (parallel per-backend) -// 3. Registers optim.find_tool and optim.call_tool as session tools + // 3. Registers optim_find_tool and optim_call_tool as session tools func (o *OptimizerIntegration) OnRegisterSession( _ context.Context, session server.ClientSession, @@ -140,7 +140,76 @@ func (o *OptimizerIntegration) OnRegisterSession( return nil } +// RegisterGlobalTools registers optimizer tools globally (available to all sessions). +// This should be called during server initialization, before any sessions are created. +// Registering tools globally ensures they are immediately available when clients connect, +// avoiding timing issues where list_tools is called before per-session registration completes. +func (o *OptimizerIntegration) RegisterGlobalTools() error { + if o == nil { + return nil // Optimizer not enabled + } + + // Define optimizer tools with handlers + findToolHandler := o.createFindToolHandler() + callToolHandler := o.CreateCallToolHandler() + + // Register optim_find_tool globally + o.mcpServer.AddTool(mcp.Tool{ + Name: "optim_find_tool", + Description: "Semantic search across all backend tools using natural language description and optional keywords", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + }, + }, findToolHandler) + + // Register optim_call_tool globally + o.mcpServer.AddTool(mcp.Tool { + Name: "optim_call_tool", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + }, + }, callToolHandler) + + logger.Info("Optimizer tools registered globally (optim_find_tool, optim_call_tool)") + return nil +} + // RegisterTools adds optimizer tools to the session. +// Even though tools are registered globally via RegisterGlobalTools(), +// with WithToolCapabilities(false), we also need to register them per-session +// to ensure they appear in list_tools responses. // This should be called after OnRegisterSession completes. func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.ClientSession) error { if o == nil { @@ -149,11 +218,11 @@ func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.C sessionID := session.SessionID() - // Define optimizer tools with handlers + // Define optimizer tools with handlers (same as global registration) optimizerTools := []server.ServerTool{ { Tool: mcp.Tool{ - Name: "optim.find_tool", + Name: "optim_find_tool", Description: "Semantic search across all backend tools using natural language description and optional keywords", InputSchema: mcp.ToolInputSchema{ Type: "object", @@ -179,7 +248,7 @@ func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.C }, { Tool: mcp.Tool{ - Name: "optim.call_tool", + Name: "optim_call_tool", Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", InputSchema: mcp.ToolInputSchema{ Type: "object", @@ -204,16 +273,71 @@ func (o *OptimizerIntegration) RegisterTools(_ context.Context, session server.C }, } - // Add tools to session + // Add tools to session (required when WithToolCapabilities(false)) if err := o.mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { return fmt.Errorf("failed to add optimizer tools to session: %w", err) } - logger.Debugw("Optimizer tools registered", "session_id", sessionID) + logger.Debugw("Optimizer tools registered for session", "session_id", sessionID) return nil } -// CreateFindToolHandler creates the handler for optim.find_tool +// GetOptimizerToolDefinitions returns the tool definitions for optimizer tools +// without handlers. This is useful for adding tools to capabilities before session registration. +func (o *OptimizerIntegration) GetOptimizerToolDefinitions() []mcp.Tool { + if o == nil { + return nil + } + return []mcp.Tool{ + { + Name: "optim_find_tool", + Description: "Semantic search across all backend tools using natural language description and optional keywords", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + }, + }, + { + Name: "optim_call_tool", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + }, + }, + } +} + +// CreateFindToolHandler creates the handler for optim_find_tool // Exported for testing purposes func (o *OptimizerIntegration) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return o.createFindToolHandler() @@ -335,10 +459,10 @@ func convertSearchResultsToResponse( return responseTools, totalReturnedTokens } -// createFindToolHandler creates the handler for optim.find_tool +// createFindToolHandler creates the handler for optim_find_tool func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - logger.Debugw("optim.find_tool called", "request", request) + logger.Debugw("optim_find_tool called", "request", request) // Extract parameters from request arguments args, ok := request.Params.Arguments.(map[string]any) @@ -423,7 +547,7 @@ func (o *OptimizerIntegration) createFindToolHandler() func(context.Context, mcp return mcp.NewToolResultError(fmt.Sprintf("failed to marshal response: %v", err3)), nil } - logger.Infow("optim.find_tool completed", + logger.Infow("optim_find_tool completed", "query", toolDescription, "results_count", len(responseTools), "tokens_saved", tokensSaved, @@ -456,7 +580,7 @@ func (*OptimizerIntegration) recordTokenMetrics( returnedCounter, err := meter.Int64Counter( "toolhive_vmcp_optimizer_returned_tokens", - metric.WithDescription("Total tokens for tools returned by optim.find_tool"), + metric.WithDescription("Total tokens for tools returned by optim_find_tool"), ) if err != nil { logger.Debugw("Failed to create returned_tokens counter", "error", err) @@ -465,7 +589,7 @@ func (*OptimizerIntegration) recordTokenMetrics( savedCounter, err := meter.Int64Counter( "toolhive_vmcp_optimizer_tokens_saved", - metric.WithDescription("Number of tokens saved by filtering tools with optim.find_tool"), + metric.WithDescription("Number of tokens saved by filtering tools with optim_find_tool"), ) if err != nil { logger.Debugw("Failed to create tokens_saved counter", "error", err) @@ -499,16 +623,16 @@ func (*OptimizerIntegration) recordTokenMetrics( "savings_percentage", savingsPercentage) } -// CreateCallToolHandler creates the handler for optim.call_tool +// CreateCallToolHandler creates the handler for optim_call_tool // Exported for testing purposes func (o *OptimizerIntegration) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return o.createCallToolHandler() } -// createCallToolHandler creates the handler for optim.call_tool +// createCallToolHandler creates the handler for optim_call_tool func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - logger.Debugw("optim.call_tool called", "request", request) + logger.Debugw("optim_call_tool called", "request", request) // Extract parameters from request arguments args, ok := request.Params.Arguments.(map[string]any) @@ -587,7 +711,7 @@ func (o *OptimizerIntegration) createCallToolHandler() func(context.Context, mcp return mcp.NewToolResultError(fmt.Sprintf("failed to marshal result: %v", err)), nil } - logger.Infow("optim.call_tool completed successfully", + logger.Infow("optim_call_tool completed successfully", "backend_id", backendID, "tool_name", toolName) diff --git a/pkg/vmcp/optimizer/optimizer_handlers_test.go b/pkg/vmcp/optimizer/optimizer_handlers_test.go index 3889a47e37..aa9146c058 100644 --- a/pkg/vmcp/optimizer/optimizer_handlers_test.go +++ b/pkg/vmcp/optimizer/optimizer_handlers_test.go @@ -110,7 +110,7 @@ func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { // Test with invalid arguments type request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: "not a map", }, } @@ -122,7 +122,7 @@ func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { // Test with missing tool_description request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "limit": 10, }, @@ -136,7 +136,7 @@ func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { // Test with empty tool_description request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "", }, @@ -150,7 +150,7 @@ func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { // Test with non-string tool_description request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": 123, }, @@ -217,7 +217,7 @@ func TestCreateFindToolHandler_WithKeywords(t *testing.T) { // Test with keywords request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "search tool", "tool_keywords": "test search", @@ -289,7 +289,7 @@ func TestCreateFindToolHandler_Limit(t *testing.T) { // Test with custom limit request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "test", "limit": 5, @@ -304,7 +304,7 @@ func TestCreateFindToolHandler_Limit(t *testing.T) { // Test with float64 limit (from JSON) request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "test", "limit": float64(3), @@ -332,7 +332,7 @@ func TestCreateFindToolHandler_BackendToolOpsNil(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "test", }, @@ -389,7 +389,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with invalid arguments type request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: "not a map", }, } @@ -401,7 +401,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with missing backend_id request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "tool_name": "test_tool", "parameters": map[string]any{}, @@ -416,7 +416,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with empty backend_id request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "", "tool_name": "test_tool", @@ -432,7 +432,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with missing tool_name request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "parameters": map[string]any{}, @@ -447,7 +447,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with missing parameters request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "test_tool", @@ -462,7 +462,7 @@ func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { // Test with invalid parameters type request = mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "test_tool", @@ -521,7 +521,7 @@ func TestCreateCallToolHandler_NoRoutingTable(t *testing.T) { // Test without routing table in context request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "test_tool", @@ -590,7 +590,7 @@ func TestCreateCallToolHandler_ToolNotFound(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "nonexistent_tool", @@ -664,7 +664,7 @@ func TestCreateCallToolHandler_BackendMismatch(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", // Requesting backend-1 "tool_name": "test_tool", // But tool belongs to backend-2 @@ -745,7 +745,7 @@ func TestCreateCallToolHandler_Success(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "test_tool", @@ -834,7 +834,7 @@ func TestCreateCallToolHandler_CallToolError(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.call_tool", + Name: "optim_call_tool", Arguments: map[string]any{ "backend_id": "backend-1", "tool_name": "test_tool", @@ -891,7 +891,7 @@ func TestCreateFindToolHandler_InputSchemaUnmarshalError(t *testing.T) { request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "test", }, diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go index 4742de843d..44c1a895e4 100644 --- a/pkg/vmcp/optimizer/optimizer_integration_test.go +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -303,7 +303,7 @@ func TestOptimizerIntegration_DisabledEmbeddingTime(t *testing.T) { require.NoError(t, err, "Should handle nil integration gracefully") } -// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim.find_tool +// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim_find_tool func TestOptimizerIntegration_TokenMetrics(t *testing.T) { t.Parallel() ctx := context.Background() @@ -381,10 +381,10 @@ func TestOptimizerIntegration_TokenMetrics(t *testing.T) { handler := integration.CreateFindToolHandler() require.NotNil(t, handler) - // Call optim.find_tool + // Call optim_find_tool request := mcp.CallToolRequest{ Params: mcp.CallToolParams{ - Name: "optim.find_tool", + Name: "optim_find_tool", Arguments: map[string]any{ "tool_description": "create issue", "limit": 5, diff --git a/pkg/vmcp/router/default_router.go b/pkg/vmcp/router/default_router.go index 7e32731ed9..1b4a2e1a5c 100644 --- a/pkg/vmcp/router/default_router.go +++ b/pkg/vmcp/router/default_router.go @@ -78,15 +78,15 @@ func routeCapability( // instead of using a cached routing table. // // Special handling for optimizer tools: -// - Tools with "optim." prefix (optim.find_tool, optim.call_tool) are handled by vMCP itself +// - Tools with "optim_" prefix (optim_find_tool, optim_call_tool) are handled by vMCP itself // - These tools are registered during session initialization and don't route to backends // - The SDK handles these tools directly via registered handlers func (*defaultRouter) RouteTool(ctx context.Context, toolName string) (*vmcp.BackendTarget, error) { - // Optimizer tools (optim.*) are handled by vMCP itself, not routed to backends. + // Optimizer tools (optim_*) are handled by vMCP itself, not routed to backends. // The SDK will invoke the registered handler directly. // We return ErrToolNotFound here so the handler factory doesn't try to create // a backend routing handler for these tools. - if strings.HasPrefix(toolName, "optim.") { + if strings.HasPrefix(toolName, "optim_") { logger.Debugf("Optimizer tool %s is handled by vMCP, not routed to backend", toolName) return nil, fmt.Errorf("%w: optimizer tool %s is handled by vMCP", ErrToolNotFound, toolName) } diff --git a/pkg/vmcp/server/mocks/mock_watcher.go b/pkg/vmcp/server/mocks/mock_watcher.go index 4044825b14..d88b4144f4 100644 --- a/pkg/vmcp/server/mocks/mock_watcher.go +++ b/pkg/vmcp/server/mocks/mock_watcher.go @@ -123,6 +123,20 @@ func (mr *MockOptimizerIntegrationMockRecorder) OnRegisterSession(ctx, session, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnRegisterSession", reflect.TypeOf((*MockOptimizerIntegration)(nil).OnRegisterSession), ctx, session, capabilities) } +// RegisterGlobalTools mocks base method. +func (m *MockOptimizerIntegration) RegisterGlobalTools() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RegisterGlobalTools") + ret0, _ := ret[0].(error) + return ret0 +} + +// RegisterGlobalTools indicates an expected call of RegisterGlobalTools. +func (mr *MockOptimizerIntegrationMockRecorder) RegisterGlobalTools() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterGlobalTools", reflect.TypeOf((*MockOptimizerIntegration)(nil).RegisterGlobalTools)) +} + // RegisterTools mocks base method. func (m *MockOptimizerIntegration) RegisterTools(ctx context.Context, session server.ClientSession) error { m.ctrl.T.Helper() @@ -136,3 +150,17 @@ func (mr *MockOptimizerIntegrationMockRecorder) RegisterTools(ctx, session any) mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterTools", reflect.TypeOf((*MockOptimizerIntegration)(nil).RegisterTools), ctx, session) } + +// GetOptimizerToolDefinitions mocks base method. +func (m *MockOptimizerIntegration) GetOptimizerToolDefinitions() []mcp.Tool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOptimizerToolDefinitions") + ret0, _ := ret[0].([]mcp.Tool) + return ret0 +} + +// GetOptimizerToolDefinitions indicates an expected call of GetOptimizerToolDefinitions. +func (mr *MockOptimizerIntegrationMockRecorder) GetOptimizerToolDefinitions() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOptimizerToolDefinitions", reflect.TypeOf((*MockOptimizerIntegration)(nil).GetOptimizerToolDefinitions)) +} diff --git a/pkg/vmcp/server/server.go b/pkg/vmcp/server/server.go index 8092268c63..ec1af04919 100644 --- a/pkg/vmcp/server/server.go +++ b/pkg/vmcp/server/server.go @@ -15,6 +15,7 @@ import ( "sync" "time" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" "github.com/stacklok/toolhive/pkg/audit" @@ -123,7 +124,7 @@ type Config struct { Watcher Watcher // OptimizerConfig is the optional optimizer configuration. - // If nil or Enabled=false, optimizer tools (optim.find_tool, optim.call_tool) are not available. + // If nil or Enabled=false, optimizer tools (optim_find_tool, optim_call_tool) are not available. OptimizerConfig *OptimizerConfig } @@ -222,7 +223,7 @@ type Server struct { healthMonitor *health.Monitor healthMonitorMu sync.RWMutex - // optimizerIntegration provides semantic tool discovery via optim.find_tool and optim.call_tool. + // optimizerIntegration provides semantic tool discovery via optim_find_tool and optim_call_tool. // Nil if optimizer is disabled. optimizerIntegration OptimizerIntegration } @@ -236,9 +237,20 @@ type OptimizerIntegration interface { // OnRegisterSession generates embeddings for session tools OnRegisterSession(ctx context.Context, session server.ClientSession, capabilities *aggregator.AggregatedCapabilities) error - // RegisterTools adds optim.find_tool and optim.call_tool to the session + // RegisterGlobalTools registers optim_find_tool and optim_call_tool globally (available to all sessions) + // This should be called during server initialization, before any sessions are created. + RegisterGlobalTools() error + + // RegisterTools adds optim_find_tool and optim_call_tool to the session + // Even though tools are registered globally via RegisterGlobalTools(), + // with WithToolCapabilities(false), we also need to register them per-session + // to ensure they appear in list_tools responses. RegisterTools(ctx context.Context, session server.ClientSession) error + // GetOptimizerToolDefinitions returns the tool definitions for optimizer tools without handlers. + // This is useful for adding tools to capabilities before session registration. + GetOptimizerToolDefinitions() []mcp.Tool + // Close cleans up optimizer resources Close() error } @@ -424,6 +436,13 @@ func New( } logger.Info("Optimizer integration initialized successfully") + // Register optimizer tools globally (available to all sessions immediately) + // This ensures tools are available when clients call list_tools, avoiding timing issues + // where list_tools is called before per-session registration completes + if err := optimizerInteg.RegisterGlobalTools(); err != nil { + return nil, fmt.Errorf("failed to register optimizer tools globally: %w", err) + } + // Ingest discovered backends into optimizer database (for semantic search) // Note: Backends are already discovered and registered with vMCP regardless of optimizer // This step indexes them in the optimizer database for semantic search @@ -475,6 +494,21 @@ func New( sessionID := session.SessionID() logger.Debugw("OnRegisterSession hook called", "session_id", sessionID) + // CRITICAL: Register optimizer tools FIRST, before any other processing + // This ensures tools are available immediately when clients call list_tools + // during or immediately after initialize, before other hooks complete + if srv.optimizerIntegration != nil { + if err := srv.optimizerIntegration.RegisterTools(ctx, session); err != nil { + logger.Errorw("failed to register optimizer tools", + "error", err, + "session_id", sessionID) + // Don't fail session initialization - continue without optimizer tools + } else { + logger.Debugw("optimizer tools registered for session (early registration)", + "session_id", sessionID) + } + } + // Get capabilities from context (discovered by middleware) caps, ok := discovery.DiscoveredCapabilitiesFromContext(ctx) if !ok || caps == nil { @@ -532,7 +566,7 @@ func New( "prompt_count", len(caps.RoutingTable.Prompts)) // When optimizer is enabled, we should NOT inject backend tools directly. - // Instead, only optimizer tools (optim.find_tool, optim.call_tool) will be exposed. + // Instead, only optimizer tools (optim_find_tool, optim_call_tool) will be exposed. // Backend tools are still discovered and stored for optimizer ingestion, // but not exposed directly to clients. if srv.optimizerIntegration == nil { @@ -549,17 +583,8 @@ func New( "tool_count", len(caps.Tools), "resource_count", len(caps.Resources)) } else { - // Optimizer is enabled - register optimizer tools FIRST so they're available immediately - // Backend tools will be accessible via optim.find_tool and optim.call_tool - if err := srv.optimizerIntegration.RegisterTools(ctx, session); err != nil { - logger.Errorw("failed to register optimizer tools", - "error", err, - "session_id", sessionID) - // Don't fail session initialization - continue without optimizer tools - } else { - logger.Infow("optimizer tools registered", - "session_id", sessionID) - } + // Optimizer tools already registered above (early registration) + // Backend tools will be accessible via optim_find_tool and optim_call_tool // Inject resources (but not backend tools) if len(caps.Resources) > 0 {