diff --git a/.gitignore b/.gitignore index f0840c001e..55d5cbbc5b 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,12 @@ coverage* crd-helm-wrapper cmd/vmcp/__debug_bin* + +# Demo files +examples/operator/virtual-mcps/vmcp_optimizer.yaml +scripts/k8s_vmcp_optimizer_demo.sh +examples/ingress/mcp-servers-ingress.yaml +examples/vmcp-config-optimizer.yaml +/vmcp +thv-operator +thv diff --git a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml index c5964fe9e3..b2c07ceadd 100644 --- a/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/files/crds/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -642,14 +642,6 @@ spec: - fail - best_effort type: string - statusReportingInterval: - default: 30s - description: |- - StatusReportingInterval is the interval for reporting status updates to Kubernetes. - This controls how often the vMCP runtime reports backend health and phase changes. - Lower values provide faster status updates but increase API server load. - pattern: ^([0-9]+(\.[0-9]+)?(ns|us|µs|ms|s|m|h))+$ - type: string unhealthyThreshold: default: 3 description: UnhealthyThreshold is the number of consecutive @@ -685,17 +677,76 @@ spec: optimizer: description: |- Optimizer configures the MCP optimizer for context optimization on large toolsets. - When enabled, vMCP exposes only find_tool and call_tool operations to clients + When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. properties: - embeddingService: + embeddingBackend: description: |- - EmbeddingService is the name of a Kubernetes Service that provides the embedding service - for semantic tool discovery. 
The service must implement the optimizer embedding API. + EmbeddingBackend specifies the embedding provider: "ollama", "vllm", "unified", or "openai". + - "ollama": Uses local Ollama HTTP API for embeddings + - "vllm": Uses vLLM OpenAI-compatible API (recommended for production Kubernetes deployments) + - "unified": Uses generic OpenAI-compatible API (works with both vLLM and OpenAI) + - "openai": Uses OpenAI-compatible API + enum: + - ollama + - vllm + - unified + - openai + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. + type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. 
+ Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + Default: 70 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 100 + minimum: 0 + type: integer + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. type: string - required: - - embeddingService type: object outgoingAuth: description: |- diff --git a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml index 6378a74603..d7b2b250e3 100644 --- a/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml +++ b/deploy/charts/operator-crds/templates/toolhive.stacklok.dev_virtualmcpservers.yaml @@ -645,14 +645,6 @@ spec: - fail - best_effort type: string - statusReportingInterval: - default: 30s - description: |- - StatusReportingInterval is the interval for reporting status updates to Kubernetes. - This controls how often the vMCP runtime reports backend health and phase changes. - Lower values provide faster status updates but increase API server load. - pattern: ^([0-9]+(\.[0-9]+)?(ns|us|µs|ms|s|m|h))+$ - type: string unhealthyThreshold: default: 3 description: UnhealthyThreshold is the number of consecutive @@ -688,17 +680,74 @@ spec: optimizer: description: |- Optimizer configures the MCP optimizer for context optimization on large toolsets. - When enabled, vMCP exposes only find_tool and call_tool operations to clients + When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients instead of all backend tools directly. This reduces token usage by allowing LLMs to discover relevant tools on demand rather than receiving all tool definitions. 
properties: - embeddingService: + embeddingBackend: description: |- - EmbeddingService is the name of a Kubernetes Service that provides the embedding service - for semantic tool discovery. The service must implement the optimizer embedding API. + EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + - "ollama": Uses local Ollama HTTP API for embeddings + - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + enum: + - ollama + - openai-compatible + - placeholder + type: string + embeddingDimension: + description: |- + EmbeddingDimension is the dimension of the embedding vectors. + Common values: + - 384: all-MiniLM-L6-v2, nomic-embed-text + - 768: BAAI/bge-small-en-v1.5 + - 1536: OpenAI text-embedding-3-small + minimum: 1 + type: integer + embeddingModel: + description: |- + EmbeddingModel is the model name to use for embeddings. + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "nomic-embed-text", "all-minilm" + - vLLM: "BAAI/bge-small-en-v1.5" + - OpenAI: "text-embedding-3-small" + type: string + embeddingURL: + description: |- + EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + Required when EmbeddingBackend is "ollama" or "openai-compatible". + Examples: + - Ollama: "http://localhost:11434" + - vLLM: "http://vllm-service:8000/v1" + - OpenAI: "https://api.openai.com/v1" + type: string + enabled: + description: |- + Enabled determines whether the optimizer is active. + When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. + type: boolean + ftsDBPath: + description: |- + FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + Hybrid search (semantic + BM25) is always enabled. 
+ type: string + hybridSearchRatio: + description: |- + HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + Default: 70 (70% semantic, 30% BM25) + Only used when FTSDBPath is set. + maximum: 100 + minimum: 0 + type: integer + persistPath: + description: |- + PersistPath is the optional filesystem path for persisting the chromem-go database. + If empty, the database will be in-memory only (ephemeral). + When set, tool metadata and embeddings are persisted to disk for faster restarts. type: string - required: - - embeddingService type: object outgoingAuth: description: |- diff --git a/go.mod b/go.mod index 5ef7e657d7..39fbfb0af5 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/stacklok/toolhive -go 1.25.6 +go 1.25.5 require ( dario.cat/mergo v1.0.2 @@ -29,6 +29,7 @@ require ( github.com/onsi/ginkgo/v2 v2.27.5 github.com/onsi/gomega v1.39.0 github.com/ory/fosite v0.49.0 + github.com/philippgille/chromem-go v0.7.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/prometheus/client_golang v1.23.2 github.com/sigstore/protobuf-specs v0.5.0 @@ -59,6 +60,7 @@ require ( k8s.io/api v0.35.0 k8s.io/apimachinery v0.35.0 k8s.io/utils v0.0.0-20260108192941-914a6e750570 + modernc.org/sqlite v1.44.0 sigs.k8s.io/controller-runtime v0.22.4 sigs.k8s.io/yaml v1.6.0 ) @@ -174,6 +176,7 @@ require ( github.com/muesli/termenv v0.16.0 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oklog/ulid v1.3.1 // indirect github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect github.com/olekukonko/errors v1.1.0 // indirect @@ -188,6 +191,7 @@ require ( github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/otlptranslator v1.0.0 // indirect 
github.com/prometheus/procfs v0.19.2 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.11.0 // indirect @@ -251,6 +255,9 @@ require ( k8s.io/apiextensions-apiserver v0.34.1 // indirect k8s.io/klog/v2 v2.130.1 // indirect k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 // indirect + modernc.org/libc v1.67.4 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 // indirect sigs.k8s.io/randfill v1.0.0 // indirect sigs.k8s.io/structured-merge-diff/v6 v6.3.0 // indirect @@ -268,7 +275,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/lestrrat-go/blackmagic v1.0.4 // indirect github.com/lestrrat-go/httpcc v1.0.1 // indirect @@ -286,7 +293,7 @@ require ( go.opentelemetry.io/otel/metric v1.39.0 go.opentelemetry.io/otel/trace v1.39.0 golang.org/x/crypto v0.47.0 - golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/sys v0.40.0 k8s.io/client-go v0.35.0 ) diff --git a/go.sum b/go.sum index 6b6795e3d4..8a1997bac9 100644 --- a/go.sum +++ b/go.sum @@ -333,8 +333,8 @@ github.com/gofrs/uuid v4.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRx github.com/gofrs/uuid v4.3.1+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt/v5 v5.3.1 
h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= -github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= github.com/golang/mock v1.7.0-rc.1 h1:YojYx61/OLFsiv6Rw1Z96LpldJIy31o+UHmwAUMJ6/U= @@ -602,6 +602,8 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f h1:y5//uYreIhSUg3J github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/natefinch/atomic v1.0.1 h1:ZPYKxkqQOx3KZ+RsbnP/YsgvxWQPGxjC0oBt2AhwV0A= github.com/natefinch/atomic v1.0.1/go.mod h1:N/D/ELrljoqDyT3rZrsUmtsuzvHkeB/wWjHV22AZRbM= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/nyaruka/phonenumbers v1.1.6 h1:DcueYq7QrOArAprAYNoQfDgp0KetO4LqtnBtQC6Wyes= github.com/nyaruka/phonenumbers v1.1.6/go.mod h1:yShPJHDSH3aTKzCbXyVxNpbl2kA+F+Ne5Pun/MvFRos= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= @@ -640,6 +642,8 @@ github.com/ory/x v0.0.665 h1:61vv0ObCDSX1vOQYbxBeqDiv4YiPmMT91lYxDaaKX08= github.com/ory/x v0.0.665/go.mod h1:7SCTki3N0De3ZpqlxhxU/94ZrOCfNEnXwVtd0xVt+L8= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/philippgille/chromem-go v0.7.0 h1:4jfvfyKymjKNfGxBUhHUcj1kp7B17NL/I1P+vGh1RvY= +github.com/philippgille/chromem-go v0.7.0/go.mod 
h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pjbgf/sha1cd v0.3.2 h1:a9wb0bp1oC2TGwStyn0Umc/IGKQnEgF0vVaZ8QF8eo4= github.com/pjbgf/sha1cd v0.3.2/go.mod h1:zQWigSxVmsHEZow5qaLtPYxpcKMMQpa09ixqBxuCS6A= github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ= @@ -661,6 +665,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -909,8 +915,8 @@ golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= -golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/exp/event 
v0.0.0-20251219203646-944ab1f22d93 h1:Fee8ke0jLfLhU4ywDLs7IYmhJ8MrSP0iZE3p39EKKSc= golang.org/x/exp/event v0.0.0-20251219203646-944ab1f22d93/go.mod h1:HgAgrKXB9WF2wFZJBGBnRVkmsC8n+v2ja/8VR0H3QkY= golang.org/x/exp/jsonrpc2 v0.0.0-20260112195511-716be5621a96 h1:cN9X2vSBmT3Ruw2UlbJNLJh0iBqTmtSB0dRfh5aumiY= @@ -1086,6 +1092,34 @@ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912 h1:Y3gxNAuB0OBLImH611+UDZ k8s.io/kube-openapi v0.0.0-20250910181357-589584f1c912/go.mod h1:kdmbQkyfwUagLfXIad1y2TdrjPFWp2Q89B3qkRwf/pQ= k8s.io/utils v0.0.0-20260108192941-914a6e750570 h1:JT4W8lsdrGENg9W+YwwdLJxklIuKWdRm+BC+xt33FOY= k8s.io/utils v0.0.0-20260108192941-914a6e750570/go.mod h1:xDxuJ0whA3d0I4mf/C4ppKHxXynQ+fxnkmQH0vTHnuk= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.4 h1:zZGmCMUVPORtKv95c2ReQN5VDjvkoRm9GWPTEPuvlWg= +modernc.org/libc v1.67.4/go.mod h1:QvvnnJ5P7aitu0ReNpVIEyesuhmDLQ8kaEoyMjIFZJA= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= 
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.0 h1:YjCKJnzZde2mLVy0cMKTSL4PxCmbIguOq9lGp8ZvGOc= +modernc.org/sqlite v1.44.0/go.mod h1:2Dq41ir5/qri7QJJJKNZcP4UF7TsX/KNeykYgPDtGhE= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A= sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= diff --git a/pkg/vmcp/config/config.go b/pkg/vmcp/config/config.go index 4f25a469e1..f477c01232 100644 --- a/pkg/vmcp/config/config.go +++ b/pkg/vmcp/config/config.go @@ -151,7 +151,7 @@ type Config struct { Audit *audit.Config `json:"audit,omitempty" yaml:"audit,omitempty"` // Optimizer configures the MCP optimizer for context optimization on large toolsets. - // When enabled, vMCP exposes only find_tool and call_tool operations to clients + // When enabled, vMCP exposes optim_find_tool and optim_call_tool operations to clients // instead of all backend tools directly. This reduces token usage by allowing // LLMs to discover relevant tools on demand rather than receiving all tool definitions. 
// +optional @@ -447,13 +447,6 @@ type FailureHandlingConfig struct { // +optional UnhealthyThreshold int `json:"unhealthyThreshold,omitempty" yaml:"unhealthyThreshold,omitempty"` - // StatusReportingInterval is the interval for reporting status updates to Kubernetes. - // This controls how often the vMCP runtime reports backend health and phase changes. - // Lower values provide faster status updates but increase API server load. - // +kubebuilder:default="30s" - // +optional - StatusReportingInterval Duration `json:"statusReportingInterval,omitempty" yaml:"statusReportingInterval,omitempty"` - // PartialFailureMode defines behavior when some backends are unavailable. // - fail: Fail entire request if any backend is unavailable // - best_effort: Continue with available backends @@ -703,16 +696,72 @@ type OutputProperty struct { Default thvjson.Any `json:"default,omitempty" yaml:"default,omitempty"` } -// OptimizerConfig configures the MCP optimizer. -// When enabled, vMCP exposes only find_tool and call_tool operations to clients -// instead of all backend tools directly. +// OptimizerConfig configures the MCP optimizer for semantic tool discovery. +// The optimizer reduces token usage by allowing LLMs to discover relevant tools +// on demand rather than receiving all tool definitions upfront. // +kubebuilder:object:generate=true // +gendoc type OptimizerConfig struct { - // EmbeddingService is the name of a Kubernetes Service that provides the embedding service - // for semantic tool discovery. The service must implement the optimizer embedding API. - // +kubebuilder:validation:Required - EmbeddingService string `json:"embeddingService" yaml:"embeddingService"` + // Enabled determines whether the optimizer is active. + // When true, vMCP exposes optim_find_tool and optim_call_tool instead of all backend tools. 
+ // +optional + Enabled bool `json:"enabled" yaml:"enabled"` + + // EmbeddingBackend specifies the embedding provider: "ollama", "openai-compatible", or "placeholder". + // - "ollama": Uses local Ollama HTTP API for embeddings + // - "openai-compatible": Uses OpenAI-compatible API (vLLM, OpenAI, etc.) + // - "placeholder": Uses deterministic hash-based embeddings (for testing/development) + // +kubebuilder:validation:Enum=ollama;openai-compatible;placeholder + // +optional + EmbeddingBackend string `json:"embeddingBackend,omitempty" yaml:"embeddingBackend,omitempty"` + + // EmbeddingURL is the base URL for the embedding service (Ollama or OpenAI-compatible API). + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "http://localhost:11434" + // - vLLM: "http://vllm-service:8000/v1" + // - OpenAI: "https://api.openai.com/v1" + // +optional + EmbeddingURL string `json:"embeddingURL,omitempty" yaml:"embeddingURL,omitempty"` + + // EmbeddingModel is the model name to use for embeddings. + // Required when EmbeddingBackend is "ollama" or "openai-compatible". + // Examples: + // - Ollama: "nomic-embed-text", "all-minilm" + // - vLLM: "BAAI/bge-small-en-v1.5" + // - OpenAI: "text-embedding-3-small" + // +optional + EmbeddingModel string `json:"embeddingModel,omitempty" yaml:"embeddingModel,omitempty"` + + // EmbeddingDimension is the dimension of the embedding vectors. + // Common values: + // - 384: all-MiniLM-L6-v2, nomic-embed-text + // - 768: BAAI/bge-small-en-v1.5 + // - 1536: OpenAI text-embedding-3-small + // +kubebuilder:validation:Minimum=1 + // +optional + EmbeddingDimension int `json:"embeddingDimension,omitempty" yaml:"embeddingDimension,omitempty"` + + // PersistPath is the optional filesystem path for persisting the chromem-go database. + // If empty, the database will be in-memory only (ephemeral). + // When set, tool metadata and embeddings are persisted to disk for faster restarts. 
+ // +optional + PersistPath string `json:"persistPath,omitempty" yaml:"persistPath,omitempty"` + + // FTSDBPath is the path to the SQLite FTS5 database for BM25 text search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // Hybrid search (semantic + BM25) is always enabled. + // +optional + FTSDBPath string `json:"ftsDBPath,omitempty" yaml:"ftsDBPath,omitempty"` + + // HybridSearchRatio controls the mix of semantic vs BM25 results in hybrid search. + // Value range: 0 (all BM25) to 100 (all semantic), representing percentage. + // Default: 70 (70% semantic, 30% BM25) + // Only used when FTSDBPath is set. + // +optional + // +kubebuilder:validation:Minimum=0 + // +kubebuilder:validation:Maximum=100 + HybridSearchRatio *int `json:"hybridSearchRatio,omitempty" yaml:"hybridSearchRatio,omitempty"` } // Validator validates configuration. diff --git a/pkg/vmcp/config/defaults.go b/pkg/vmcp/config/defaults.go index ab8b6b8f9c..834f2b7625 100644 --- a/pkg/vmcp/config/defaults.go +++ b/pkg/vmcp/config/defaults.go @@ -20,9 +20,6 @@ const ( // before marking a backend as unhealthy. defaultUnhealthyThreshold = 3 - // defaultStatusReportingInterval is the default interval for reporting status updates. - defaultStatusReportingInterval = 30 * time.Second - // defaultPartialFailureMode defines the default behavior when some backends fail. // "fail" means the entire request fails if any backend is unavailable. 
defaultPartialFailureMode = "fail" @@ -52,10 +49,9 @@ func DefaultOperationalConfig() *OperationalConfig { PerWorkload: nil, }, FailureHandling: &FailureHandlingConfig{ - HealthCheckInterval: Duration(defaultHealthCheckInterval), - UnhealthyThreshold: defaultUnhealthyThreshold, - StatusReportingInterval: Duration(defaultStatusReportingInterval), - PartialFailureMode: defaultPartialFailureMode, + HealthCheckInterval: Duration(defaultHealthCheckInterval), + UnhealthyThreshold: defaultUnhealthyThreshold, + PartialFailureMode: defaultPartialFailureMode, CircuitBreaker: &CircuitBreakerConfig{ Enabled: defaultCircuitBreakerEnabled, FailureThreshold: defaultCircuitBreakerFailureThreshold, diff --git a/pkg/vmcp/optimizer/dummy_optimizer.go b/pkg/vmcp/optimizer/dummy_optimizer.go index 00c9be9eae..fcaafd27d6 100644 --- a/pkg/vmcp/optimizer/dummy_optimizer.go +++ b/pkg/vmcp/optimizer/dummy_optimizer.go @@ -10,7 +10,11 @@ import ( "strings" "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + mcpserver "github.com/mark3labs/mcp-go/server" + + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" ) // DummyOptimizer implements the Optimizer interface using exact string matching. @@ -21,19 +25,19 @@ import ( // For production use, see the EmbeddingOptimizer which uses semantic similarity. type DummyOptimizer struct { // tools contains all available tools indexed by name. - tools map[string]server.ServerTool + tools map[string]mcpserver.ServerTool } // NewDummyOptimizer creates a new DummyOptimizer with the given tools. // // The tools slice should contain all backend tools (as ServerTool with handlers). 
-func NewDummyOptimizer(tools []server.ServerTool) Optimizer { - toolMap := make(map[string]server.ServerTool, len(tools)) +func NewDummyOptimizer(tools []mcpserver.ServerTool) Optimizer { + toolMap := make(map[string]mcpserver.ServerTool, len(tools)) for _, tool := range tools { toolMap[tool.Tool.Name] = tool } - return DummyOptimizer{ + return &DummyOptimizer{ tools: toolMap, } } @@ -46,7 +50,7 @@ func NewDummyOptimizer(tools []server.ServerTool) Optimizer { // // Returns all matching tools with a score of 1.0 (exact match semantics). // TokenMetrics are returned as zero values (not implemented in dummy). -func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { +func (d *DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindToolOutput, error) { if input.ToolDescription == "" { return nil, fmt.Errorf("tool_description is required") } @@ -65,14 +69,25 @@ func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindT return nil, err } matches = append(matches, ToolMatch{ - Name: tool.Tool.Name, - Description: tool.Tool.Description, - InputSchema: schema, - Score: 1.0, // Exact match semantics + Name: tool.Tool.Name, + Description: tool.Tool.Description, + InputSchema: schema, + BackendID: "dummy", + SimilarityScore: 1.0, // Exact match semantics + TokenCount: 0, // Not implemented in dummy }) } } + // Apply limit if specified + limit := input.Limit + if limit <= 0 { + limit = 10 + } + if len(matches) > limit { + matches = matches[:limit] + } + return &FindToolOutput{ Tools: matches, TokenMetrics: TokenMetrics{}, // Zero values for dummy @@ -83,7 +98,7 @@ func (d DummyOptimizer) FindTool(_ context.Context, input FindToolInput) (*FindT // // The tool is looked up by exact name match. If found, the handler // is invoked directly with the given parameters. 
-func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { +func (d *DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { if input.ToolName == "" { return nil, fmt.Errorf("tool_name is required") } @@ -103,17 +118,108 @@ func (d DummyOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp return tool.Handler(ctx, request) } +// Close is a no-op for DummyOptimizer. +func (*DummyOptimizer) Close() error { + return nil +} + +// HandleSessionRegistration is a no-op for DummyOptimizer. +// Returns false to indicate the dummy optimizer doesn't handle session registration. +func (*DummyOptimizer) HandleSessionRegistration( + _ context.Context, + _ string, + _ *aggregator.AggregatedCapabilities, + _ *mcpserver.MCPServer, + _ func([]vmcp.Resource) []mcpserver.ServerResource, +) (bool, error) { + return false, nil +} + +// CreateFindToolHandler implements adapter.OptimizerHandlerProvider. +func (d *DummyOptimizer) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Parse input from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + input := FindToolInput{} + if desc, ok := args["tool_description"].(string); ok { + input.ToolDescription = desc + } + if kw, ok := args["tool_keywords"].(string); ok { + input.ToolKeywords = kw + } + if limit, ok := args["limit"].(float64); ok { + input.Limit = int(limit) + } + + output, err := d.FindTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil + } + return mcp.NewToolResultStructuredOnly(output), nil + } +} + +// CreateCallToolHandler implements adapter.OptimizerHandlerProvider. 
+func (d *DummyOptimizer) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Parse input from request arguments + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + input := CallToolInput{} + if name, ok := args["tool_name"].(string); ok { + input.ToolName = name + } + if params, ok := args["parameters"].(map[string]any); ok { + input.Parameters = params + } + + result, err := d.CallTool(ctx, input) + if err != nil { + return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil + } + return result, nil + } +} + // getToolSchema returns the input schema for a tool. // Prefers RawInputSchema if set, otherwise marshals InputSchema. -func getToolSchema(tool mcp.Tool) (json.RawMessage, error) { +func getToolSchema(tool mcp.Tool) (map[string]any, error) { + var result map[string]any + if len(tool.RawInputSchema) > 0 { - return tool.RawInputSchema, nil + if err := json.Unmarshal(tool.RawInputSchema, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal raw input schema for tool %s: %w", tool.Name, err) + } + return result, nil } - // Fall back to InputSchema + // Fall back to InputSchema - convert to map via JSON round-trip + // Check if InputSchema has any content by marshaling it data, err := json.Marshal(tool.InputSchema) if err != nil { return nil, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) } - return data, nil + + // Empty struct marshals to "{}", only process if not empty + if string(data) != "{}" && string(data) != "null" { + if err := json.Unmarshal(data, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal input schema for tool %s: %w", tool.Name, err) + } + return result, nil + } + + return nil, nil } + +// Verify DummyOptimizer implements 
Optimizer interface. +var _ Optimizer = (*DummyOptimizer)(nil) + +// Verify DummyOptimizer implements OptimizerHandlerProvider interface. +var _ adapter.OptimizerHandlerProvider = (*DummyOptimizer)(nil) diff --git a/pkg/vmcp/optimizer/find_tool_semantic_search_test.go b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go new file mode 100644 index 0000000000..742401d04a --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_semantic_search_test.go @@ -0,0 +1,689 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +const ( + testBackendOllama = "ollama" + testBackendOpenAI = "openai" +) + +// verifyEmbeddingBackendWorking verifies that the embedding backend is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyEmbeddingBackendWorking(t *testing.T, manager *embeddings.Manager, backendType string) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + if backendType == testBackendOllama { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } else { + t.Skipf("Skipping test: Embedding backend is reachable but embedding generation failed. 
Error: %v", err) + } + } +} + +// TestFindTool_SemanticSearch tests semantic search capabilities +// These tests verify that find_tool can find tools based on semantic meaning, +// not just exact keyword matches +func TestFindTool_SemanticSearch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingBackend := testBackendOllama + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, // all-MiniLM-L6-v2 dimension + } + + // Test if Ollama is available + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible (might be vLLM or Ollama v1 API) + embeddingConfig.BackendType = testBackendOpenAI + embeddingConfig.BaseURL = "http://localhost:11434" + embeddingConfig.Model = embeddings.DefaultModelAllMiniLM + embeddingConfig.Dimension = 768 + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping semantic search test: No embedding backend available (Ollama or OpenAI-compatible). 
Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + + // Setup optimizer integration with high semantic ratio to favor semantic search + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + hybridRatio := 90 // 90% semantic, 10% BM25 to test semantic search + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddingBackend, + EmbeddingURL: embeddingConfig.BaseURL, + EmbeddingModel: embeddingConfig.Model, + EmbeddingDimension: embeddingConfig.Dimension, + HybridSearchRatio: &hybridRatio, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with diverse descriptions to test semantic understanding + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository.", + 
BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "github_get_branch", + Description: "Get information about a branch in a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases for semantic search - queries that mean the same thing but use different words + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should be found semantically + description string + }{ + { + name: "semantic_pr_synonyms", + query: "view code review request", + keywords: "", + expectedTools: 
[]string{"github_pull_request_read", "github_list_pull_requests"}, + description: "Should find PR tools using semantic synonyms (code review = pull request)", + }, + { + name: "semantic_merge_synonyms", + query: "combine code changes", + keywords: "", + expectedTools: []string{"github_merge_pull_request"}, + description: "Should find merge tool using semantic meaning (combine = merge)", + }, + { + name: "semantic_create_synonyms", + query: "make a new code review", + keywords: "", + expectedTools: []string{"github_create_pull_request", "github_list_pull_requests", "github_pull_request_read"}, + description: "Should find PR-related tools using semantic meaning (make = create, code review = PR)", + }, + { + name: "semantic_issue_synonyms", + query: "show bug reports", + keywords: "", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + description: "Should find issue tools using semantic synonyms (bug report = issue)", + }, + { + name: "semantic_repository_synonyms", + query: "start a new project", + keywords: "", + expectedTools: []string{"github_create_repository"}, + description: "Should find repository tool using semantic meaning (project = repository)", + }, + { + name: "semantic_commit_synonyms", + query: "get change details", + keywords: "", + expectedTools: []string{"github_get_commit"}, + description: "Should find commit tool using semantic meaning (change = commit)", + }, + { + name: "semantic_fetch_synonyms", + query: "download web page content", + keywords: "", + expectedTools: []string{"fetch_fetch"}, + description: "Should find fetch tool using semantic synonyms (download = fetch)", + }, + { + name: "semantic_branch_synonyms", + query: "get branch information", + keywords: "", + expectedTools: []string{"github_get_branch"}, + description: "Should find branch tool using semantic meaning", + }, + { + name: "semantic_related_concepts", + query: "code collaboration features", + keywords: "", + expectedTools: 
[]string{"github_pull_request_read", "github_create_pull_request", "github_issue_read"}, + description: "Should find collaboration-related tools (PRs and issues are collaboration features)", + }, + { + name: "semantic_intent_based", + query: "I want to see what code changes were made", + keywords: "", + expectedTools: []string{"github_get_commit", "github_pull_request_read"}, + description: "Should find tools based on user intent (seeing code changes = commits/PRs)", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error for query: %s", tc.query) + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText, "Result should be text content") + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray, "Response should have tools array") + require.NotEmpty(t, toolsArray, "Should return at least one result for semantic query: %s", tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = 
append(foundTools, toolName) + + // Verify similarity score exists and is reasonable + similarity, okScore := toolMap["similarity_score"].(float64) + require.True(t, okScore, "Tool should have similarity_score") + assert.Greater(t, similarity, 0.0, "Similarity score should be positive") + } + + // Check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + assert.GreaterOrEqual(t, foundCount, 1, + "Semantic query '%s' should find at least one expected tool from %v. Found tools: %v (found %d/%d)", + tc.query, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log results for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Semantic query '%s': Found %d/%d expected tools. Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, okMetrics := response["token_metrics"].(map[string]interface{}) + require.True(t, okMetrics, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + }) + } +} + +// TestFindTool_SemanticVsKeyword tests that semantic search finds different results than keyword search +func TestFindTool_SemanticVsKeyword(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if 
err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Test with high semantic ratio + hybridRatioSemantic := 90 // 90% semantic + configSemantic := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-semantic"), + EmbeddingBackend: embeddingBackend, + EmbeddingURL: embeddingConfig.BaseURL, + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatioSemantic, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integrationSemantic, err := NewIntegration(ctx, configSemantic, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationSemantic.Close() }() + + // Test with low semantic ratio (high BM25) + hybridRatioKeyword := 10 // 10% semantic, 90% BM25 + configKeyword := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db-keyword"), + EmbeddingBackend: embeddingBackend, + EmbeddingURL: embeddingConfig.BaseURL, + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatioKeyword, + } + + integrationKeyword, err := NewIntegration(ctx, configKeyword, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integrationKeyword.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: 
"github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Register both integrations + err = integrationSemantic.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + err = integrationKeyword.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integrationSemantic.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + err = integrationKeyword.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + // Query that has semantic meaning but no exact keyword match + query := "view code review" + + // Test semantic search + requestSemantic := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerSemantic := integrationSemantic.CreateFindToolHandler() + resultSemantic, err := handlerSemantic(ctxWithCaps, requestSemantic) + require.NoError(t, err) + require.False(t, resultSemantic.IsError) + + // Test keyword search + requestKeyword := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, 
+ "tool_keywords": "", + "limit": 10, + }, + }, + } + + handlerKeyword := integrationKeyword.CreateFindToolHandler() + resultKeyword, err := handlerKeyword(ctxWithCaps, requestKeyword) + require.NoError(t, err) + require.False(t, resultKeyword.IsError) + + // Parse both results + textSemantic, _ := mcp.AsTextContent(resultSemantic.Content[0]) + var responseSemantic map[string]any + json.Unmarshal([]byte(textSemantic.Text), &responseSemantic) + + textKeyword, _ := mcp.AsTextContent(resultKeyword.Content[0]) + var responseKeyword map[string]any + json.Unmarshal([]byte(textKeyword.Text), &responseKeyword) + + toolsSemantic, _ := responseSemantic["tools"].([]interface{}) + toolsKeyword, _ := responseKeyword["tools"].([]interface{}) + + // Both should find results (semantic should find PR tools, keyword might not) + assert.NotEmpty(t, toolsSemantic, "Semantic search should find results") + assert.NotEmpty(t, toolsKeyword, "Keyword search should find results") + + // Semantic search should find pull request tools even without exact keyword match + foundPRSemantic := false + for _, toolInterface := range toolsSemantic { + toolMap, _ := toolInterface.(map[string]interface{}) + toolName, _ := toolMap["name"].(string) + if toolName == "github_pull_request_read" { + foundPRSemantic = true + break + } + } + + t.Logf("Semantic search (90%% semantic): Found %d tools", len(toolsSemantic)) + t.Logf("Keyword search (10%% semantic): Found %d tools", len(toolsKeyword)) + t.Logf("Semantic search found PR tool: %v", foundPRSemantic) + + // Semantic search should be able to find semantically related tools + // even when keywords don't match exactly + assert.True(t, foundPRSemantic, + "Semantic search should find 'github_pull_request_read' for query 'view code review' even without exact keyword match") +} + +// TestFindTool_SemanticSimilarityScores tests that similarity scores are meaningful +func TestFindTool_SemanticSimilarityScores(t *testing.T) { + t.Parallel() + ctx := 
context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available + embeddingBackend := "ollama" + embeddingConfig := &embeddings.Config{ + BackendType: embeddingBackend, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + // Try OpenAI-compatible + embeddingConfig.BackendType = testBackendOpenAI + embeddingManager, err = embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: No embedding backend available. Error: %v", err) + return + } + embeddingBackend = testBackendOpenAI + } + + // Verify embedding backend is actually working, not just reachable + verifyEmbeddingBackendWorking(t, embeddingManager, embeddingBackend) + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + hybridRatio := 90 // High semantic ratio + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddingBackend, + EmbeddingURL: embeddingConfig.BaseURL, + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatio, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: 
"fetch", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Query for pull request + query := "view pull request" + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": "", + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError) + + textContent, _ := mcp.AsTextContent(result.Content[0]) + var response map[string]any + json.Unmarshal([]byte(textContent.Text), &response) + + toolsArray, _ := response["tools"].([]interface{}) + require.NotEmpty(t, toolsArray) + + // Check that results are sorted by similarity (highest first) + var similarities []float64 + for _, toolInterface := range toolsArray { + toolMap, _ := toolInterface.(map[string]interface{}) + similarity, _ := toolMap["similarity_score"].(float64) + similarities = append(similarities, similarity) + } + + // Verify results are 
sorted by similarity (descending) + for i := 1; i < len(similarities); i++ { + assert.GreaterOrEqual(t, similarities[i-1], similarities[i], + "Results should be sorted by similarity score (descending). Scores: %v", similarities) + } + + // The most relevant tool (pull request) should have a higher similarity than unrelated tools + if len(similarities) > 1 { + // First result should have highest similarity + assert.Greater(t, similarities[0], 0.0, "Top result should have positive similarity") + } +} diff --git a/pkg/vmcp/optimizer/find_tool_string_matching_test.go b/pkg/vmcp/optimizer/find_tool_string_matching_test.go new file mode 100644 index 0000000000..65e0fd0a38 --- /dev/null +++ b/pkg/vmcp/optimizer/find_tool_string_matching_test.go @@ -0,0 +1,696 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// verifyOllamaWorking verifies that Ollama is actually working by attempting to generate an embedding +// This ensures the service is not just reachable but actually functional +func verifyOllamaWorking(t *testing.T, manager *embeddings.Manager) { + t.Helper() + _, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Skipf("Skipping test: Ollama is reachable but embedding generation failed. Error: %v. 
Ensure 'ollama pull %s' has been executed", err, embeddings.DefaultModelAllMiniLM) + } +} + +// getRealToolData returns test data based on actual MCP server tools +// These are real tool descriptions from GitHub and other MCP servers +func getRealToolData() []vmcp.Tool { + return []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_pull_requests", + Description: "List pull requests in a GitHub repository. If the user specifies an author, then DO NOT use this tool and use the search_pull_requests tool instead.", + BackendID: "github", + }, + { + Name: "github_search_pull_requests", + Description: "Search for pull requests in GitHub repositories using issues search syntax already scoped to is:pr", + BackendID: "github", + }, + { + Name: "github_create_pull_request", + Description: "Create a new pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_merge_pull_request", + Description: "Merge a pull request in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_pull_request_review_write", + Description: "Create and/or submit, delete review of a pull request.", + BackendID: "github", + }, + { + Name: "github_issue_read", + Description: "Get information about a specific issue in a GitHub repository.", + BackendID: "github", + }, + { + Name: "github_list_issues", + Description: "List issues in a GitHub repository. 
For pagination, use the 'endCursor' from the previous response's 'pageInfo' in the 'after' parameter.", + BackendID: "github", + }, + { + Name: "github_create_repository", + Description: "Create a new GitHub repository in your account or specified organization", + BackendID: "github", + }, + { + Name: "github_get_commit", + Description: "Get details for a commit from a GitHub repository", + BackendID: "github", + }, + { + Name: "fetch_fetch", + Description: "Fetches a URL from the internet and optionally extracts its contents as markdown.", + BackendID: "fetch", + }, + } +} + +// TestFindTool_StringMatching tests that find_tool can match strings correctly +func TestFindTool_StringMatching(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + hybridRatio := 50 // 50% semantic, 50% BM25 for better string matching + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatio, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Get real tool data + tools := getRealToolData() + + // Create capabilities with real tools + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + // Build routing table + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + // Register session and generate embeddings + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) 
+ + // Create context with capabilities + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test cases: query -> expected tool names that should be found + testCases := []struct { + name string + query string + keywords string + expectedTools []string // Tools that should definitely be in results + minResults int // Minimum number of results expected + description string + }{ + { + name: "exact_pull_request_match", + query: "pull request", + keywords: "pull request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests", "github_create_pull_request"}, + minResults: 3, + description: "Should find tools with exact 'pull request' string match", + }, + { + name: "pull_request_in_name", + query: "pull request", + keywords: "pull_request", + expectedTools: []string{"github_pull_request_read", "github_list_pull_requests"}, + minResults: 2, + description: "Should match tools with 'pull_request' in name", + }, + { + name: "list_pull_requests", + query: "list pull requests", + keywords: "list pull requests", + expectedTools: []string{"github_list_pull_requests"}, + minResults: 1, + description: "Should find list pull requests tool", + }, + { + name: "read_pull_request", + query: "read pull request", + keywords: "read pull request", + expectedTools: []string{"github_pull_request_read"}, + minResults: 1, + description: "Should find read pull request tool", + }, + { + name: "create_pull_request", + query: "create pull request", + keywords: "create pull request", + expectedTools: []string{"github_create_pull_request"}, + minResults: 1, + description: "Should find create pull request tool", + }, + { + name: "merge_pull_request", + query: "merge pull request", + keywords: "merge pull request", + expectedTools: []string{"github_merge_pull_request"}, + minResults: 1, + description: "Should find merge pull request tool", + }, + { + name: "search_pull_requests", + query: "search pull requests", + keywords: "search pull requests", 
+ expectedTools: []string{"github_search_pull_requests"}, + minResults: 1, + description: "Should find search pull requests tool", + }, + { + name: "issue_tools", + query: "issue", + keywords: "issue", + expectedTools: []string{"github_issue_read", "github_list_issues"}, + minResults: 2, + description: "Should find issue-related tools", + }, + { + name: "repository_tool", + query: "create repository", + keywords: "create repository", + expectedTools: []string{"github_create_repository"}, + minResults: 1, + description: "Should find create repository tool", + }, + { + name: "commit_tool", + query: "get commit", + keywords: "commit", + expectedTools: []string{"github_get_commit"}, + minResults: 1, + description: "Should find get commit tool", + }, + { + name: "fetch_tool", + query: "fetch URL", + keywords: "fetch", + expectedTools: []string{"fetch_fetch"}, + minResults: 1, + description: "Should find fetch tool", + }, + } + + for _, tc := range testCases { + tc := tc // capture loop variable + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create the tool call request + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 20, + }, + }, + } + + // Call the handler + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError, "Tool call should not return error") + + // Parse the result + require.NotEmpty(t, result.Content, "Result should have content") + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok, "Result should be text content") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err, "Result should be valid JSON") + + // Check tools array exists + toolsArray, ok := 
response["tools"].([]interface{}) + require.True(t, ok, "Response should have tools array") + require.GreaterOrEqual(t, len(toolsArray), tc.minResults, + "Should return at least %d results for query: %s", tc.minResults, tc.query) + + // Extract tool names from results + foundTools := make([]string, 0, len(toolsArray)) + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap, "Tool should be a map") + toolName, okName := toolMap["name"].(string) + require.True(t, okName, "Tool should have name") + foundTools = append(foundTools, toolName) + } + + // Check that at least some expected tools are found + // String matching may not be perfect, so we check that at least one expected tool is found + foundCount := 0 + for _, expectedTool := range tc.expectedTools { + for _, foundTool := range foundTools { + if foundTool == expectedTool { + foundCount++ + break + } + } + } + + // We should find at least one expected tool, or at least 50% of expected tools + minExpected := 1 + if len(tc.expectedTools) > 1 { + half := len(tc.expectedTools) / 2 + if half > minExpected { + minExpected = half + } + } + + assert.GreaterOrEqual(t, foundCount, minExpected, + "Query '%s' should find at least %d of expected tools %v. Found tools: %v (found %d/%d)", + tc.query, minExpected, tc.expectedTools, foundTools, foundCount, len(tc.expectedTools)) + + // Log which expected tools were found for debugging + if foundCount < len(tc.expectedTools) { + t.Logf("Query '%s': Found %d/%d expected tools. 
Found: %v, Expected: %v", + tc.query, foundCount, len(tc.expectedTools), foundTools, tc.expectedTools) + } + + // Verify token metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]interface{}) + require.True(t, ok, "Response should have token_metrics") + assert.Contains(t, tokenMetrics, "baseline_tokens") + assert.Contains(t, tokenMetrics, "returned_tokens") + assert.Contains(t, tokenMetrics, "tokens_saved") + assert.Contains(t, tokenMetrics, "savings_percentage") + }) + } +} + +// TestFindTool_ExactStringMatch tests that exact string matches work correctly +func TestFindTool_ExactStringMatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration with higher BM25 ratio for better string matching + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + hybridRatio := 30 // 30% semantic, 70% BM25 for better exact string matching + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatio, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + // Create tools with specific strings to match + tools := []vmcp.Tool{ + { + Name: "test_pull_request_tool", + Description: "This tool handles pull requests in GitHub", + BackendID: "test", + }, + { + Name: "test_issue_tool", + Description: "This tool handles issues in GitHub", + BackendID: "test", + }, + { + Name: "test_repository_tool", + Description: "This tool creates repositories", + BackendID: "test", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + for _, tool := range tools { + capabilities.RoutingTable.Tools[tool.Name] = &vmcp.BackendTarget{ + WorkloadID: tool.BackendID, + WorkloadName: tool.BackendID, + } + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + 
mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "test", "test", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test exact string matching + testCases := []struct { + name string + query string + keywords string + expectedTool string + description string + }{ + { + name: "exact_pull_request_string", + query: "pull request", + keywords: "pull request", + expectedTool: "test_pull_request_tool", + description: "Should match exact 'pull request' string", + }, + { + name: "exact_issue_string", + query: "issue", + keywords: "issue", + expectedTool: "test_issue_tool", + description: "Should match exact 'issue' string", + }, + { + name: "exact_repository_string", + query: "repository", + keywords: "repository", + expectedTool: "test_repository_tool", + description: "Should match exact 'repository' string", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": tc.query, + "tool_keywords": tc.keywords, + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + require.NotEmpty(t, toolsArray, "Should find at least one tool for query: %s", tc.query) + + // Check that the expected tool is in the results + 
found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == tc.expectedTool { + found = true + break + } + } + + assert.True(t, found, + "Expected tool '%s' not found in results for query '%s'. This indicates string matching is not working correctly.", + tc.expectedTool, tc.query) + }) + } +} + +// TestFindTool_CaseInsensitive tests case-insensitive string matching +func TestFindTool_CaseInsensitive(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Verify Ollama is actually working, not just reachable + verifyOllamaWorking(t, embeddingManager) + + hybridRatio := 30 // Favor BM25 for string matching + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + HybridSearchRatio: &hybridRatio, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + t.Cleanup(func() { _ = integration.Close() }) + + tools := []vmcp.Tool{ + { + Name: "github_pull_request_read", + Description: "Get information on a specific pull request in GitHub repository.", + BackendID: "github", + }, + } + + capabilities := &aggregator.AggregatedCapabilities{ + Tools: tools, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "github_pull_request_read": { + WorkloadID: "github", + WorkloadName: "github", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + session := &mockSession{sessionID: "test-session"} + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Manually ingest tools for testing (OnRegisterSession skips ingestion) + mcpTools := make([]mcp.Tool, len(tools)) + for i, tool := range tools { + mcpTools[i] = mcp.Tool{ + Name: tool.Name, + Description: tool.Description, + } + } + err = integration.IngestToolsForTesting(ctx, "github", "GitHub", nil, mcpTools) + require.NoError(t, err) + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + // Test 
different case variations + queries := []string{ + "PULL REQUEST", + "Pull Request", + "pull request", + "PuLl ReQuEsT", + } + + for _, query := range queries { + query := query + t.Run("case_"+strings.ToLower(query), func(t *testing.T) { + t.Parallel() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": query, + "tool_keywords": strings.ToLower(query), + "limit": 10, + }, + }, + } + + handler := integration.CreateFindToolHandler() + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.IsError) + + textContent, okText := mcp.AsTextContent(result.Content[0]) + require.True(t, okText) + + var response map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + toolsArray, okArray := response["tools"].([]interface{}) + require.True(t, okArray) + + // Should find the pull request tool regardless of case + found := false + for _, toolInterface := range toolsArray { + toolMap, okMap := toolInterface.(map[string]interface{}) + require.True(t, okMap) + toolName, okName := toolMap["name"].(string) + require.True(t, okName) + if toolName == "github_pull_request_read" { + found = true + break + } + } + + assert.True(t, found, + "Should find pull request tool with case-insensitive query: %s", query) + }) + } +} diff --git a/pkg/vmcp/optimizer/internal/INTEGRATION.md b/pkg/vmcp/optimizer/internal/INTEGRATION.md new file mode 100644 index 0000000000..a231a0dabb --- /dev/null +++ b/pkg/vmcp/optimizer/internal/INTEGRATION.md @@ -0,0 +1,134 @@ +# Integrating Optimizer with vMCP + +## Overview + +The optimizer package ingests MCP server and tool metadata into a searchable database with semantic embeddings. This enables intelligent tool discovery and token optimization for LLM consumption. 
+ +## Integration Approach + +**Event-Driven Ingestion**: The optimizer integrates directly with vMCP's startup process. When vMCP starts and loads its configured servers, it calls the optimizer to ingest each server's metadata and tools. + +❌ **NOT** a separate polling service discovering backends +✅ **IS** called directly by vMCP during server initialization + +## How It Is Integrated + +The optimizer is already integrated into vMCP and works automatically when enabled via configuration. Here's how the integration works: + +### Initialization + +When vMCP starts with optimizer enabled in the configuration, it: + +1. Initializes the optimizer database (chromem-go + SQLite FTS5) +2. Configures the embedding backend (placeholder, Ollama, or vLLM) +3. Sets up the ingestion service + +### Automatic Ingestion + +The optimizer integrates with vMCP's `OnRegisterSession` hook, which is called whenever: + +- vMCP starts and loads configured MCP servers +- A new MCP server is dynamically added +- A session reconnects or refreshes + +When this hook is triggered, the optimizer: + +1. Retrieves the server's metadata and tools via MCP protocol +2. Generates embeddings for searchable content +3. Stores the data in both the vector database (chromem-go) and FTS5 database +4. 
Makes the tools immediately available for semantic search
+
+### Exposed Tools
+
+When the optimizer is enabled, vMCP automatically exposes these tools to LLM clients:
+
+- `optim_find_tool`: Semantic search for tools across all registered servers
+- `optim_call_tool`: Dynamic tool invocation after discovery
+
+### Implementation Location
+
+The integration code is located in:
+- `cmd/vmcp/optimizer.go`: Optimizer initialization and configuration
+- `pkg/vmcp/optimizer/optimizer.go`: Session registration hook implementation
+- `cmd/thv-operator/pkg/optimizer/ingestion/service.go`: Core ingestion service
+
+## Configuration
+
+Add optimizer configuration to vMCP's config:
+
+```yaml
+# vMCP config
+optimizer:
+  enabled: true
+  db_path: /data/optimizer.db
+  embedding:
+    backend: vllm  # or "ollama" for local dev, "placeholder" for testing
+    url: http://vllm-service:8000
+    model: sentence-transformers/all-MiniLM-L6-v2
+    dimension: 384
+```
+
+## Error Handling
+
+**Important**: Optimizer failures should NOT break vMCP functionality:
+
+- ✅ Log warnings if optimizer fails
+- ✅ Continue server startup even if ingestion fails
+- ✅ Run ingestion in goroutines to avoid blocking
+- ❌ Don't fail server startup if optimizer is unavailable
+
+## Benefits
+
+1. **Automatic**: Servers are indexed as they're added to vMCP
+2. **Up-to-date**: Database reflects current vMCP state
+3. **No polling**: Event-driven, efficient
+4. **Semantic search**: Enables intelligent tool discovery
+5.
**Token optimization**: Tracks token usage for LLM efficiency
+
+## Testing
+
+```go
+func TestOptimizerIntegration(t *testing.T) {
+	// Initialize optimizer
+	optimizerSvc, err := ingestion.NewService(&ingestion.Config{
+		DBConfig: &db.Config{Path: "/tmp/test-optimizer.db"},
+		EmbeddingConfig: &embeddings.Config{
+			BackendType: "ollama",
+			BaseURL:     "http://localhost:11434",
+			Model:       "all-minilm",
+			Dimension:   384,
+		},
+	})
+	require.NoError(t, err)
+	defer optimizerSvc.Close()
+
+	// Simulate vMCP starting a server
+	ctx := context.Background()
+	tools := []mcp.Tool{
+		{Name: "get_weather", Description: "Get current weather"},
+		{Name: "get_forecast", Description: "Get weather forecast"},
+	}
+
+	err = optimizerSvc.IngestServer(
+		ctx,
+		"weather-001",
+		"weather-service",
+		"http://weather.local",
+		models.TransportSSE,
+		ptr("Weather information service"),
+		tools,
+	)
+	require.NoError(t, err)
+
+	// Verify ingestion
+	server, err := optimizerSvc.GetServer(ctx, "weather-001")
+	require.NoError(t, err)
+	assert.Equal(t, "weather-service", server.Name)
+}
+```
+
+## See Also
+
+- [Optimizer Package README](./README.md) - Package overview and API
+
diff --git a/pkg/vmcp/optimizer/internal/README.md b/pkg/vmcp/optimizer/internal/README.md
new file mode 100644
index 0000000000..7db703b711
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/README.md
@@ -0,0 +1,339 @@
+# Optimizer Package
+
+The optimizer package provides semantic tool discovery and ingestion for MCP servers in ToolHive's vMCP. It enables intelligent, context-aware tool selection to reduce token usage and improve LLM performance.
+ +## Features + +- **Pure Go**: No CGO dependencies - uses [chromem-go](https://github.com/philippgille/chromem-go) for vector search and `modernc.org/sqlite` for FTS5 +- **Hybrid Search**: Combines semantic search (chromem-go) with BM25 full-text search (SQLite FTS5) +- **In-Memory by Default**: Fast ephemeral database with optional persistence +- **Pluggable Embeddings**: Supports vLLM, Ollama, and placeholder backends +- **Event-Driven**: Integrates with vMCP's `OnRegisterSession` hook for automatic ingestion +- **Semantic + Keyword Search**: Configurable ratio between semantic and BM25 search +- **Token Counting**: Tracks token usage for LLM consumption metrics + +## Architecture + +``` +cmd/thv-operator/pkg/optimizer/ +├── models/ # Domain models (Server, Tool, etc.) +├── db/ # Hybrid database layer (chromem-go + SQLite FTS5) +│ ├── db.go # Database coordinator +│ ├── fts.go # SQLite FTS5 for BM25 search (pure Go) +│ ├── hybrid.go # Hybrid search combining semantic + BM25 +│ ├── backend_server.go # Server operations +│ └── backend_tool.go # Tool operations +├── embeddings/ # Embedding backends (vLLM, Ollama, placeholder) +├── ingestion/ # Event-driven ingestion service +└── tokens/ # Token counting for LLM metrics +``` + +## Embedding Backends + +The optimizer supports multiple embedding backends: + +| Backend | Use Case | Performance | Setup | +|---------|----------|-------------|-------| +| **vLLM** | **Production/Kubernetes (recommended)** | Excellent (GPU) | Deploy vLLM service | +| Ollama | Local development, CPU-only | Good | `ollama serve` | +| Placeholder | Testing, CI/CD | Fast (hash-based) | Zero setup | + +**For production Kubernetes deployments, vLLM is recommended** due to its high-throughput performance, GPU efficiency (PagedAttention), and scalability for multi-user environments. + +## Hybrid Search + +The optimizer **always uses hybrid search** combining: + +1. 
**Semantic Search** (chromem-go): Understands meaning and context via embeddings
+2. **BM25 Full-Text Search** (SQLite FTS5): Keyword matching with Porter stemming
+
+This dual approach ensures the best of both worlds: semantic understanding for intent-based queries and keyword precision for technical terms and acronyms.
+
+### Configuration
+
+```yaml
+optimizer:
+  enabled: true
+  embeddingBackend: placeholder
+  embeddingDimension: 384
+  # persistPath: /data/optimizer  # Optional: for persistence
+  # ftsDBPath: /data/optimizer-fts.db  # Optional: defaults to :memory: or {persistPath}/fts.db
+  hybridSearchRatio: 70  # 70% semantic, 30% BM25 (default, 0-100 percentage)
+```
+
+| Ratio | Semantic | BM25 | Best For |
+|-------|----------|------|----------|
+| 100 | 100% | 0% | Pure semantic (intent-heavy queries) |
+| 70 | 70% | 30% | **Default**: Balanced hybrid |
+| 50 | 50% | 50% | Equal weight |
+| 0 | 0% | 100% | Pure keyword (exact term matching) |
+
+### How It Works
+
+1. **Parallel Execution**: Semantic and BM25 searches run concurrently
+2. **Result Merging**: Combines results and removes duplicates
+3. **Ranking**: Sorts by similarity/relevance score
+4. **Limit Enforcement**: Returns top N results
+
+### Example Queries
+
+| Query | Semantic Match | BM25 Match | Winner |
+|-------|----------------|------------|--------|
+| "What's the weather?"
| ✅ `get_current_weather` | ✅ `weather_forecast` | Both (deduped) |
+| "SQL database query" | ❌ (no semantic match) | ✅ `execute_sql` | BM25 |
+| "Make it rain outside" | ✅ `weather_control` | ❌ (no keyword) | Semantic |
+
+## Quick Start
+
+### vMCP Integration (Recommended)
+
+The optimizer is designed to work as part of vMCP, not standalone:
+
+```yaml
+# examples/vmcp-config-optimizer.yaml
+optimizer:
+  enabled: true
+  embeddingBackend: placeholder  # or "ollama", "openai-compatible"
+  embeddingDimension: 384
+  # persistPath: /data/optimizer  # Optional: for chromem-go persistence
+  # ftsDBPath: /data/fts.db  # Optional: auto-defaults to :memory: or {persistPath}/fts.db
+  # hybridSearchRatio: 70  # Optional: 70% semantic, 30% BM25 (default, 0-100 percentage)
+```
+
+Start vMCP with optimizer:
+
+```bash
+thv vmcp serve --config examples/vmcp-config-optimizer.yaml
+```
+
+When optimizer is enabled, vMCP exposes:
+- `optim_find_tool`: Semantic search for tools
+- `optim_call_tool`: Dynamic tool invocation
+
+### Programmatic Usage
+
+```go
+import (
+	"context"
+
+	"github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/db"
+	"github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/embeddings"
+	"github.com/stacklok/toolhive/cmd/thv-operator/pkg/optimizer/ingestion"
+)
+
+func main() {
+	ctx := context.Background()
+
+	// Initialize database (in-memory)
+	database, err := db.NewDB(&db.Config{
+		PersistPath: "", // Empty = in-memory only
+	})
+	if err != nil {
+		panic(err)
+	}
+
+	// Initialize embedding manager with Ollama (default)
+	embeddingMgr, err := embeddings.NewManager(&embeddings.Config{
+		BackendType: "ollama",
+		BaseURL:     "http://localhost:11434",
+		Model:       "all-minilm",
+		Dimension:   384,
+	})
+	if err != nil {
+		panic(err)
+	}
+
+	// Create ingestion service
+	svc, err := ingestion.NewService(&ingestion.Config{
+		DBConfig:        &db.Config{PersistPath: ""},
+		EmbeddingConfig: embeddingMgr.Config(),
+	})
+	if err != nil {
+		panic(err)
+	}
+	defer svc.Close()
+ + // Ingest a server (called by vMCP on session registration) + err = svc.IngestServer(ctx, "server-id", "MyServer", nil, []mcp.Tool{...}) + if err != nil { + panic(err) + } +} +``` + +### Production Deployment with vLLM (Kubernetes) + +```yaml +optimizer: + enabled: true + embeddingBackend: openai-compatible + embeddingURL: http://vllm-service:8000/v1 + embeddingModel: BAAI/bge-small-en-v1.5 + embeddingDimension: 768 + persistPath: /data/optimizer # Persistent storage for faster restarts +``` + +Deploy vLLM alongside vMCP: + +```yaml +apiVersion: apps/v1 +kind: Deployment +metadata: + name: vllm-embeddings +spec: + template: + spec: + containers: + - name: vllm + image: vllm/vllm-openai:latest + args: + - --model + - BAAI/bge-small-en-v1.5 + - --port + - "8000" + resources: + limits: + nvidia.com/gpu: 1 +``` + +### Local Development with Ollama + +```bash +# Start Ollama +ollama serve + +# Pull an embedding model +ollama pull all-minilm +``` + +Configure vMCP: + +```yaml +optimizer: + enabled: true + embeddingBackend: ollama + embeddingURL: http://localhost:11434 + embeddingModel: all-minilm + embeddingDimension: 384 +``` + +## Configuration + +### Database + +- **Storage**: chromem-go (pure Go, no CGO) +- **Default**: In-memory (ephemeral) +- **Persistence**: Optional via `persistPath` +- **Format**: Binary (gob encoding) + +### Embedding Models + +Common embedding dimensions: +- **384**: all-MiniLM-L6-v2, nomic-embed-text (default) +- **768**: BAAI/bge-small-en-v1.5 +- **1536**: OpenAI text-embedding-3-small + +### Performance + +From chromem-go benchmarks (mid-range 2020 Intel laptop): +- **1,000 tools**: ~0.5ms query time +- **5,000 tools**: ~2.2ms query time +- **25,000 tools**: ~9.9ms query time +- **100,000 tools**: ~39.6ms query time + +Perfect for typical vMCP deployments (hundreds to thousands of tools). + +## Testing + +Run the unit tests: + +```bash +# Test all packages +go test ./cmd/thv-operator/pkg/optimizer/... 
+ +# Test with coverage +go test -cover ./cmd/thv-operator/pkg/optimizer/... + +# Test specific package +go test ./cmd/thv-operator/pkg/optimizer/models +``` + +## Inspecting the Database + +The optimizer uses a hybrid database (chromem-go + SQLite FTS5). Here's how to inspect each: + +### Inspecting SQLite FTS5 (Easiest) + +The FTS5 database is standard SQLite and can be opened with any SQLite tool: + +```bash +# Use sqlite3 CLI +sqlite3 /tmp/vmcp-optimizer-fts.db + +# Count documents +SELECT COUNT(*) FROM backend_servers_fts; +SELECT COUNT(*) FROM backend_tools_fts; + +# View tool names and descriptions +SELECT tool_name, tool_description FROM backend_tools_fts LIMIT 10; + +# Full-text search with BM25 ranking +SELECT tool_name, rank +FROM backend_tool_fts_index +WHERE backend_tool_fts_index MATCH 'github repository' +ORDER BY rank +LIMIT 5; + +# Join servers and tools +SELECT s.name, t.tool_name, t.tool_description +FROM backend_tools_fts t +JOIN backend_servers_fts s ON t.mcpserver_id = s.id +LIMIT 10; +``` + +**VSCode Extension**: Install `alexcvzz.vscode-sqlite` to view `.db` files directly in VSCode. + +### Inspecting chromem-go (Vector Database) + +chromem-go uses `.gob` binary files. Use the provided inspection scripts: + +```bash +# Quick summary (shows collection sizes and first few documents) +go run scripts/inspect-chromem-raw.go /tmp/vmcp-optimizer-debug.db + +# View specific tool with full metadata and embeddings +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db get_file_contents + +# View all documents (warning: lots of output) +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db + +# Search by content +go run scripts/view-chromem-tool.go /tmp/vmcp-optimizer-debug.db "search" +``` + +### chromem-go Schema + +Each document in chromem-go contains: + +```go +Document { + ID: string // "github" or UUID for tools + Content: string // "tool_name. description..." 
+ Embedding: []float32 // 384-dimensional vector + Metadata: map[string]string // {"type": "backend_tool", "server_id": "github", "data": "...JSON..."} +} +``` + +**Collections**: +- `backend_servers`: Server metadata (3 documents in typical setup) +- `backend_tools`: Tool metadata and embeddings (40+ documents) + +## Known Limitations + +1. **Scale**: Optimized for <100,000 tools (more than sufficient for typical vMCP deployments) +2. **Approximate Search**: chromem-go uses exhaustive search (not HNSW), but this is fine for our scale +3. **Persistence Format**: Binary gob format (not human-readable) + +## License + +This package is part of ToolHive and follows the same license. diff --git a/pkg/vmcp/optimizer/internal/db/backend_server.go b/pkg/vmcp/optimizer/internal/db/backend_server.go new file mode 100644 index 0000000000..bbaea358f9 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/backend_server.go @@ -0,0 +1,136 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package db provides chromem-go based database operations for the optimizer. +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// backendServerOps provides operations for backend servers in chromem-go +// This is a private implementation detail. Use the Database interface instead. 
+type backendServerOps struct { + db *chromemDB + embeddingFunc chromem.EmbeddingFunc +} + +// newBackendServerOps creates a new backendServerOps instance +func newBackendServerOps(db *chromemDB, embeddingFunc chromem.EmbeddingFunc) *backendServerOps { + return &backendServerOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// create adds a new backend server to the collection +func (ops *backendServerOps) create(ctx context.Context, server *models.BackendServer) error { + collection, err := ops.db.getOrCreateCollection(ctx, BackendServerCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend server collection: %w", err) + } + + // Prepare content for embedding (name + description) + content := server.Name + if server.Description != nil && *server.Description != "" { + content += ". " + *server.Description + } + + // Serialize metadata + metadata, err := serializeServerMetadata(server) + if err != nil { + return fmt.Errorf("failed to serialize server metadata: %w", err) + } + + // Create document + doc := chromem.Document{ + ID: server.ID, + Content: content, + Metadata: metadata, + } + + // If embedding is provided, use it + if len(server.ServerEmbedding) > 0 { + doc.Embedding = server.ServerEmbedding + } + + // Add document to chromem-go collection + err = collection.AddDocument(ctx, doc) + if err != nil { + return fmt.Errorf("failed to add server document to chromem-go: %w", err) + } + + // Also add to FTS5 database if available (for keyword filtering) + // Use background context to avoid cancellation issues - FTS5 is supplementary + if ftsDB := ops.db.getFTSDB(); ftsDB != nil { + // Use background context with timeout for FTS operations + // This ensures FTS operations complete even if the original context is canceled + ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := ftsDB.UpsertServer(ftsCtx, server); err != nil { + // Log but don't fail - FTS5 is supplementary 
+ logger.Warnf("Failed to upsert server to FTS5: %v", err) + } + } + + logger.Debugf("Created backend server: %s (chromem-go + FTS5)", server.ID) + return nil +} + +// update updates an existing backend server (creates if not exists) +func (ops *backendServerOps) update(ctx context.Context, server *models.BackendServer) error { + // chromem-go doesn't have an update operation, so we delete and re-create + err := ops.delete(ctx, server.ID) + if err != nil { + // If server doesn't exist, that's fine + logger.Debugf("Server %s not found for update, will create new", server.ID) + } + + return ops.create(ctx, server) +} + +// delete removes a backend server +func (ops *backendServerOps) delete(ctx context.Context, serverID string) error { + collection, err := ops.db.getCollection(BackendServerCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist, nothing to delete + return nil + } + + err = collection.Delete(ctx, nil, nil, serverID) + if err != nil { + return fmt.Errorf("failed to delete server from chromem-go: %w", err) + } + + // Also delete from FTS5 database if available + if ftsDB := ops.db.getFTSDB(); ftsDB != nil { + if err := ftsDB.DeleteServer(ctx, serverID); err != nil { + // Log but don't fail + logger.Warnf("Failed to delete server from FTS5: %v", err) + } + } + + logger.Debugf("Deleted backend server: %s (chromem-go + FTS5)", serverID) + return nil +} + +// Helper functions for metadata serialization + +func serializeServerMetadata(server *models.BackendServer) (map[string]string, error) { + data, err := json.Marshal(server) + if err != nil { + return nil, err + } + return map[string]string{ + "data": string(data), + "type": "backend_server", + }, nil +} diff --git a/pkg/vmcp/optimizer/internal/db/backend_tool.go b/pkg/vmcp/optimizer/internal/db/backend_tool.go new file mode 100644 index 0000000000..0971f1f01d --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/backend_tool.go @@ -0,0 +1,235 @@ +// SPDX-FileCopyrightText: Copyright 
2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// backendToolOps provides operations for backend tools in chromem-go +// This is a private implementation detail. Use the Database interface instead. +type backendToolOps struct { + db *chromemDB + embeddingFunc chromem.EmbeddingFunc +} + +// newBackendToolOps creates a new backendToolOps instance +func newBackendToolOps(db *chromemDB, embeddingFunc chromem.EmbeddingFunc) *backendToolOps { + return &backendToolOps{ + db: db, + embeddingFunc: embeddingFunc, + } +} + +// create adds a new backend tool to the collection +func (ops *backendToolOps) create(ctx context.Context, tool *models.BackendTool, serverName string) error { + collection, err := ops.db.getOrCreateCollection(ctx, BackendToolCollection, ops.embeddingFunc) + if err != nil { + return fmt.Errorf("failed to get backend tool collection: %w", err) + } + + // Prepare content for embedding (name + description + input schema summary) + content := tool.ToolName + if tool.Description != nil && *tool.Description != "" { + content += ". 
" + *tool.Description
+	}
+
+	// Serialize metadata
+	metadata, err := serializeToolMetadata(tool)
+	if err != nil {
+		return fmt.Errorf("failed to serialize tool metadata: %w", err)
+	}
+
+	// Create document
+	doc := chromem.Document{
+		ID:       tool.ID,
+		Content:  content,
+		Metadata: metadata,
+	}
+
+	// If embedding is provided, use it
+	if len(tool.ToolEmbedding) > 0 {
+		doc.Embedding = tool.ToolEmbedding // pre-computed embedding skips the embeddingFunc call inside chromem-go
+	}
+
+	// Add document to chromem-go collection
+	err = collection.AddDocument(ctx, doc)
+	if err != nil {
+		return fmt.Errorf("failed to add tool document to chromem-go: %w", err)
+	}
+
+	// Also add to FTS5 database if available (for BM25 search)
+	// Use background context to avoid cancellation issues - FTS5 is supplementary
+	if ops.db.fts != nil {
+		// Use background context with timeout for FTS operations
+		// This ensures FTS operations complete even if the original context is canceled
+		ftsCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+		defer cancel() // runs at create() return, not loop end; one doc per call so this is fine
+		if err := ops.db.fts.UpsertToolMeta(ftsCtx, tool, serverName); err != nil {
+			// Log but don't fail - FTS5 is supplementary
+			logger.Warnf("Failed to upsert tool to FTS5: %v", err)
+		}
+	}
+
+	logger.Debugf("Created backend tool: %s (chromem-go + FTS5)", tool.ID)
+	return nil
+}
+
+// deleteByServer removes all tools for a given server from both chromem-go and FTS5
+func (ops *backendToolOps) deleteByServer(ctx context.Context, serverID string) error {
+	collection, err := ops.db.getCollection(BackendToolCollection, ops.embeddingFunc)
+	if err != nil {
+		// Collection doesn't exist, nothing to delete in chromem-go
+		logger.Debug("Backend tool collection not found, skipping chromem-go deletion")
+	} else {
+		// Query all tools for this server
+		tools, err := ops.listByServer(ctx, serverID) // := intentionally shadows the outer (non-nil) err
+		if err != nil {
+			return fmt.Errorf("failed to list tools for server: %w", err)
+		}
+
+		// Delete each tool from chromem-go
+		for _, tool := range tools {
+			if err := collection.Delete(ctx, 
nil, nil, tool.ID); err != nil { + logger.Warnf("Failed to delete tool %s from chromem-go: %v", tool.ID, err) + } + } + + logger.Debugf("Deleted %d tools from chromem-go for server: %s", len(tools), serverID) + } + + // Also delete from FTS5 database if available + if ops.db.fts != nil { + if err := ops.db.fts.DeleteToolsByServer(ctx, serverID); err != nil { + logger.Warnf("Failed to delete tools from FTS5 for server %s: %v", serverID, err) + } else { + logger.Debugf("Deleted tools from FTS5 for server: %s", serverID) + } + } + + return nil +} + +// listByServer returns all tools for a given server +func (ops *backendToolOps) listByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) { + collection, err := ops.db.getCollection(BackendToolCollection, ops.embeddingFunc) + if err != nil { + // Collection doesn't exist yet, return empty list + return []*models.BackendTool{}, nil + } + + // Get count to determine nResults + count := collection.Count() + if count == 0 { + return []*models.BackendTool{}, nil + } + + // Query with a generic term and metadata filter + // Using "tool" as a generic query that should match all tools + results, err := collection.Query(ctx, "tool", count, map[string]string{"server_id": serverID}, nil) + if err != nil { + // If no tools match, return empty list + return []*models.BackendTool{}, nil + } + + tools := make([]*models.BackendTool, 0, len(results)) + for _, result := range results { + tool, err := deserializeToolMetadata(result.Metadata) + if err != nil { + logger.Warnf("Failed to deserialize tool: %v", err) + continue + } + tools = append(tools, tool) + } + + return tools, nil +} + +// search performs semantic search for backend tools +// This is used internally by searchHybrid. 
+func (ops *backendToolOps) search(
+	ctx context.Context,
+	query string,
+	limit int,
+	serverID *string,
+) ([]*models.BackendToolWithMetadata, error) {
+	collection, err := ops.db.getCollection(BackendToolCollection, ops.embeddingFunc)
+	if err != nil {
+		return []*models.BackendToolWithMetadata{}, nil // collection absent: deliberately report "no results", not an error
+	}
+
+	// Get collection count and adjust limit if necessary
+	count := collection.Count()
+	if count == 0 {
+		return []*models.BackendToolWithMetadata{}, nil
+	}
+	if limit > count {
+		limit = count // chromem-go rejects nResults greater than the document count
+	}
+
+	// Build metadata filter if server ID is provided
+	var metadataFilter map[string]string
+	if serverID != nil {
+		metadataFilter = map[string]string{"server_id": *serverID} // NOTE(review): count above is the unfiltered total; Query may still error if limit exceeds the filtered match count — confirm chromem-go behavior
+	}
+
+	results, err := collection.Query(ctx, query, limit, metadataFilter, nil)
+	if err != nil {
+		return nil, fmt.Errorf("failed to search tools: %w", err)
+	}
+
+	tools := make([]*models.BackendToolWithMetadata, 0, len(results))
+	for _, result := range results {
+		tool, err := deserializeToolMetadata(result.Metadata)
+		if err != nil {
+			logger.Warnf("Failed to deserialize tool: %v", err) // skip malformed entries rather than failing the whole search
+			continue
+		}
+
+		// Add similarity score
+		toolWithMeta := &models.BackendToolWithMetadata{
+			BackendTool: *tool,
+			Similarity:  result.Similarity,
+		}
+		tools = append(tools, toolWithMeta)
+	}
+
+	return tools, nil
+}
+
+// Helper functions for metadata serialization
+
+func serializeToolMetadata(tool *models.BackendTool) (map[string]string, error) {
+	data, err := json.Marshal(tool)
+	if err != nil {
+		return nil, err
+	}
+	return map[string]string{
+		"data":      string(data),
+		"type":      "backend_tool",
+		"server_id": tool.MCPServerID, // flat copy of the server ID enables metadata-filtered queries
+	}, nil
+}
+
+func deserializeToolMetadata(metadata map[string]string) (*models.BackendTool, error) {
+	data, ok := metadata["data"]
+	if !ok {
+		return nil, fmt.Errorf("missing data field in metadata")
+	}
+
+	var tool models.BackendTool
+	if err := json.Unmarshal([]byte(data), &tool); err != nil {
+		return nil, err
+	}
+
+	return &tool, nil
+}
diff --git 
a/pkg/vmcp/optimizer/internal/db/database_impl.go b/pkg/vmcp/optimizer/internal/db/database_impl.go new file mode 100644 index 0000000000..afed3fbbfe --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/database_impl.go @@ -0,0 +1,93 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "fmt" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// databaseImpl implements the Database interface +type databaseImpl struct { + db *chromemDB + embeddingFunc chromem.EmbeddingFunc + backendServerOps *backendServerOps + backendToolOps *backendToolOps +} + +// NewDatabase creates a new Database instance with the provided configuration and embedding function. +// This is the main entry point for creating a database instance. +func NewDatabase(config *Config, embeddingFunc chromem.EmbeddingFunc) (Database, error) { + db, err := newChromemDB(config) + if err != nil { + return nil, fmt.Errorf("failed to initialize database: %w", err) + } + + impl := &databaseImpl{ + db: db, + embeddingFunc: embeddingFunc, + } + + impl.backendServerOps = newBackendServerOps(db, embeddingFunc) + impl.backendToolOps = newBackendToolOps(db, embeddingFunc) + + return impl, nil +} + +// CreateOrUpdateServer creates or updates a backend server +func (d *databaseImpl) CreateOrUpdateServer(ctx context.Context, server *models.BackendServer) error { + return d.backendServerOps.update(ctx, server) +} + +// DeleteServer removes a backend server +func (d *databaseImpl) DeleteServer(ctx context.Context, serverID string) error { + return d.backendServerOps.delete(ctx, serverID) +} + +// CreateTool adds a new backend tool +func (d *databaseImpl) CreateTool(ctx context.Context, tool *models.BackendTool, serverName string) error { + return d.backendToolOps.create(ctx, tool, serverName) +} + +// DeleteToolsByServer removes all tools for a given server +func (d 
*databaseImpl) DeleteToolsByServer(ctx context.Context, serverID string) error {
+	return d.backendToolOps.deleteByServer(ctx, serverID)
+}
+
+// SearchToolsHybrid performs hybrid search for backend tools
+func (d *databaseImpl) SearchToolsHybrid(
+	ctx context.Context,
+	query string,
+	config *HybridSearchConfig,
+) ([]*models.BackendToolWithMetadata, error) {
+	return d.backendToolOps.searchHybrid(ctx, query, config)
+}
+
+// ListToolsByServer returns all tools for a given server
+func (d *databaseImpl) ListToolsByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) {
+	return d.backendToolOps.listByServer(ctx, serverID)
+}
+
+// GetTotalToolTokens returns the total token count across all tools
+func (d *databaseImpl) GetTotalToolTokens(ctx context.Context) (int, error) {
+	// Use FTS database to efficiently count all tool tokens
+	if d.db.fts != nil {
+		return d.db.fts.GetTotalToolTokens(ctx)
+	}
+	return 0, fmt.Errorf("FTS database not available") // defensive: newChromemDB always initializes FTS, so this branch is not expected in practice
+}
+
+// Reset clears all collections and FTS tables
+func (d *databaseImpl) Reset() {
+	d.db.reset()
+}
+
+// Close releases all database resources
+func (d *databaseImpl) Close() error {
+	return d.db.close() // idempotent: underlying FTS sql.DB.Close tolerates repeated calls
+}
diff --git a/pkg/vmcp/optimizer/internal/db/database_test.go b/pkg/vmcp/optimizer/internal/db/database_test.go
new file mode 100644
index 0000000000..2dfd4b1e43
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/db/database_test.go
@@ -0,0 +1,302 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// TestDatabase_ServerOperations tests the full lifecycle of server operations through the Database interface +func TestDatabase_ServerOperations(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDatabase(t) + defer func() { _ = db.Close() }() + + description := "A test MCP server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Test create + err := db.CreateOrUpdateServer(ctx, server) + require.NoError(t, err) + + // Test update (same as create in our implementation) + server.Name = "Updated Server" + err = db.CreateOrUpdateServer(ctx, server) + require.NoError(t, err) + + // Test delete + err = db.DeleteServer(ctx, "server-1") + require.NoError(t, err) + + // Delete non-existent server should not error + err = db.DeleteServer(ctx, "non-existent") + require.NoError(t, err) +} + +// TestDatabase_ToolOperations tests the full lifecycle of tool operations through the Database interface +func TestDatabase_ToolOperations(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDatabase(t) + defer func() { _ = db.Close() }() + + description := "Test tool for weather" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &description, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Test create + err := db.CreateTool(ctx, tool, "Test Server") + require.NoError(t, err) + + // Test list by server + tools, 
err := db.ListToolsByServer(ctx, "server-1") + require.NoError(t, err) + require.Len(t, tools, 1) + assert.Equal(t, "get_weather", tools[0].ToolName) + + // Test delete by server + err = db.DeleteToolsByServer(ctx, "server-1") + require.NoError(t, err) + + // Verify deletion + tools, err = db.ListToolsByServer(ctx, "server-1") + require.NoError(t, err) + require.Empty(t, tools) +} + +// TestDatabase_HybridSearch tests hybrid search functionality through the Database interface +func TestDatabase_HybridSearch(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDatabase(t) + defer func() { _ = db.Close() }() + + // Create test tools + weatherDesc := "Get current weather information" + weatherTool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "get_weather", + Description: &weatherDesc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + err := db.CreateTool(ctx, weatherTool, "Weather Server") + require.NoError(t, err) + + searchDesc := "Search the web for information" + searchTool := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "search_web", + Description: &searchDesc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 150, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + err = db.CreateTool(ctx, searchTool, "Search Server") + require.NoError(t, err) + + // Test hybrid search + config := &HybridSearchConfig{ + SemanticRatio: 70, + Limit: 5, + ServerID: nil, + } + + results, err := db.SearchToolsHybrid(ctx, "weather", config) + require.NoError(t, err) + require.NotEmpty(t, results) + + // Weather tool should be in results + foundWeather := false + for _, result := range results { + if result.ToolName == "get_weather" { + foundWeather = true + break + } + } + assert.True(t, foundWeather, "Weather tool should be in search results") +} + +// TestDatabase_TokenCounting tests token counting 
functionality +func TestDatabase_TokenCounting(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDatabase(t) + defer func() { _ = db.Close() }() + + // Create tool with known token count + description := "Test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: &description, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err := db.CreateTool(ctx, tool, "Test Server") + require.NoError(t, err) + + // Get total tokens - should not error even if FTS isn't fully populated yet + totalTokens, err := db.GetTotalToolTokens(ctx) + require.NoError(t, err) + // Token counting via FTS may have some timing issues in tests + assert.GreaterOrEqual(t, totalTokens, 0) + + // Add another tool + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "test_tool_2", + Description: &description, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 150, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = db.CreateTool(ctx, tool2, "Test Server") + require.NoError(t, err) + + // Get total tokens again + totalTokens, err = db.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.GreaterOrEqual(t, totalTokens, 0) +} + +// TestDatabase_Reset tests database reset functionality +func TestDatabase_Reset(t *testing.T) { + t.Parallel() + ctx := context.Background() + + db := createTestDatabase(t) + defer func() { _ = db.Close() }() + + // Add some data + description := "Test server" + server := &models.BackendServer{ + ID: "server-1", + Name: "Test Server", + Description: &description, + Group: "default", + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + err := db.CreateOrUpdateServer(ctx, server) + require.NoError(t, err) + + toolDesc := "Test tool" + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: 
&toolDesc, + InputSchema: []byte(`{"type": "object"}`), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + err = db.CreateTool(ctx, tool, "Test Server") + require.NoError(t, err) + + // Reset database + db.Reset() + + // Verify data is cleared + tools, err := db.ListToolsByServer(ctx, "server-1") + require.NoError(t, err) + assert.Empty(t, tools) +} + +// Helper function to create a test database +func createTestDatabase(t *testing.T) Database { + t.Helper() + tmpDir := t.TempDir() + + // Create embedding function + embeddingFunc := func(_ context.Context, text string) ([]float32, error) { + // Try to use Ollama if available, otherwise use simple test embeddings + config := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + manager, err := embeddings.NewManager(config) + if err != nil { + // Ollama not available, use simple test embeddings + embedding := make([]float32, 384) + for i := range embedding { + embedding[i] = float32(len(text)) * 0.001 + } + if len(text) > 0 { + embedding[0] = float32(text[0]) + } + return embedding, nil + } + defer func() { _ = manager.Close() }() + + results, err := manager.GenerateEmbedding([]string{text}) + if err != nil { + // Fallback to simple embeddings + embedding := make([]float32, 384) + for i := range embedding { + embedding[i] = float32(len(text)) * 0.001 + } + return embedding, nil + } + if len(results) == 0 { + return nil, assert.AnError + } + return results[0], nil + } + + config := &Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + FTSDBPath: ":memory:", + } + + db, err := NewDatabase(config, embeddingFunc) + require.NoError(t, err) + + return db +} diff --git a/pkg/vmcp/optimizer/internal/db/db.go b/pkg/vmcp/optimizer/internal/db/db.go new file mode 100644 index 0000000000..0644c7d2b2 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/db.go @@ -0,0 +1,217 @@ +// SPDX-FileCopyrightText: Copyright 2025 
Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + + "github.com/philippgille/chromem-go" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// Config holds database configuration +// +// The optimizer database is designed to be ephemeral - it's rebuilt from scratch +// on each startup by ingesting MCP backends. Persistence is optional and primarily +// useful for development/debugging to avoid re-generating embeddings. +type Config struct { + // PersistPath is the optional path for chromem-go persistence. + // If empty, chromem-go will be in-memory only (recommended for production). + PersistPath string + + // FTSDBPath is the path for SQLite FTS5 database for BM25 search. + // If empty, defaults to ":memory:" for in-memory FTS5, or "{PersistPath}/fts.db" if PersistPath is set. + // FTS5 is always enabled for hybrid search. + FTSDBPath string +} + +// chromemDB represents the hybrid database (chromem-go + SQLite FTS5) for optimizer data +// This is a private implementation detail. Use the Database interface instead. +type chromemDB struct { + config *Config + chromem *chromem.DB // Vector/semantic search + fts *FTSDatabase // BM25 full-text search (optional) + mu sync.RWMutex +} + +// Collection names +// +// Terminology: We use "backend_servers" and "backend_tools" to be explicit about +// tracking MCP server metadata. While vMCP uses "Backend" for the workload concept, +// the optimizer focuses on the MCP server component for semantic search and tool discovery. +// This naming convention provides clarity across the database layer. +const ( + BackendServerCollection = "backend_servers" + BackendToolCollection = "backend_tools" +) + +// newChromemDB creates a new chromem-go database with FTS5 for hybrid search +// This is a private function. Use NewDatabase instead. 
+func newChromemDB(config *Config) (*chromemDB, error) { + var chromemInstance *chromem.DB + var err error + + if config.PersistPath != "" { + logger.Infof("Creating chromem-go database with persistence at: %s", config.PersistPath) + chromemInstance, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + // Check if error is due to corrupted database (missing collection metadata) + if strings.Contains(err.Error(), "collection metadata file not found") { + logger.Warnf("Database appears corrupted, attempting to remove and recreate: %v", err) + // Try to remove corrupted database directory + // Use RemoveAll which should handle directories recursively + // If it fails, we'll try to create with a new path or fall back to in-memory + if removeErr := os.RemoveAll(config.PersistPath); removeErr != nil { + logger.Warnf("Failed to remove corrupted database directory (may be in use): %v. Will try to recreate anyway.", removeErr) + // Try to rename the corrupted directory and create a new one + backupPath := config.PersistPath + ".corrupted" + if renameErr := os.Rename(config.PersistPath, backupPath); renameErr != nil { + logger.Warnf("Failed to rename corrupted database: %v. Attempting to create database anyway.", renameErr) + // Continue and let chromem-go handle it - it might work if the corruption is partial + } else { + logger.Infof("Renamed corrupted database to: %s", backupPath) + } + } + // Retry creating the database + chromemInstance, err = chromem.NewPersistentDB(config.PersistPath, false) + if err != nil { + // If still failing, return the error but suggest manual cleanup + return nil, fmt.Errorf( + "failed to create persistent database after cleanup attempt. 
Please manually remove %s and try again: %w", + config.PersistPath, err) + } + logger.Info("Successfully recreated database after cleanup") + } else { + return nil, fmt.Errorf("failed to create persistent database: %w", err) + } + } + } else { + logger.Info("Creating in-memory chromem-go database") + chromemInstance = chromem.NewDB() + } + + db := &chromemDB{ + config: config, + chromem: chromemInstance, + } + + // Set default FTS5 path if not provided + ftsPath := config.FTSDBPath + if ftsPath == "" { + if config.PersistPath != "" { + // Persistent mode: store FTS5 alongside chromem-go + ftsPath = config.PersistPath + "/fts.db" + } else { + // In-memory mode: use SQLite in-memory database + ftsPath = ":memory:" + } + } + + // Initialize FTS5 database for BM25 text search (always enabled) + logger.Infof("Initializing FTS5 database for hybrid search at: %s", ftsPath) + ftsDB, err := NewFTSDatabase(&FTSConfig{DBPath: ftsPath}) + if err != nil { + return nil, fmt.Errorf("failed to create FTS5 database: %w", err) + } + db.fts = ftsDB + logger.Info("Hybrid search enabled (chromem-go + FTS5)") + + logger.Info("Optimizer database initialized successfully") + return db, nil +} + +// getOrCreateCollection gets an existing collection or creates a new one +func (db *chromemDB) getOrCreateCollection( + _ context.Context, + name string, + embeddingFunc chromem.EmbeddingFunc, +) (*chromem.Collection, error) { + db.mu.Lock() + defer db.mu.Unlock() + + // Try to get existing collection first + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection != nil { + return collection, nil + } + + // Create new collection if it doesn't exist + collection, err := db.chromem.CreateCollection(name, nil, embeddingFunc) + if err != nil { + return nil, fmt.Errorf("failed to create collection %s: %w", name, err) + } + + logger.Debugf("Created new collection: %s", name) + return collection, nil +} + +// getCollection gets an existing collection +func (db *chromemDB) 
getCollection(name string, embeddingFunc chromem.EmbeddingFunc) (*chromem.Collection, error) { + db.mu.RLock() + defer db.mu.RUnlock() + + collection := db.chromem.GetCollection(name, embeddingFunc) + if collection == nil { + return nil, fmt.Errorf("collection not found: %s", name) + } + return collection, nil +} + +// deleteCollection deletes a collection +func (db *chromemDB) deleteCollection(name string) { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(name) + logger.Debugf("Deleted collection: %s", name) +} + +// close closes both databases +func (db *chromemDB) close() error { + logger.Info("Closing optimizer databases") + // chromem-go doesn't need explicit close, but FTS5 does + if db.fts != nil { + if err := db.fts.Close(); err != nil { + return fmt.Errorf("failed to close FTS database: %w", err) + } + } + return nil +} + +// getChromemDB returns the underlying chromem.DB instance +func (db *chromemDB) getChromemDB() *chromem.DB { + return db.chromem +} + +// getFTSDB returns the FTS database (may be nil if FTS is disabled) +func (db *chromemDB) getFTSDB() *FTSDatabase { + return db.fts +} + +// reset clears all collections and FTS tables (useful for testing and startup) +func (db *chromemDB) reset() { + db.mu.Lock() + defer db.mu.Unlock() + + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendServerCollection) + //nolint:errcheck,gosec // DeleteCollection in chromem-go doesn't return an error + db.chromem.DeleteCollection(BackendToolCollection) + + // Clear FTS5 tables if available + if db.fts != nil { + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_tools_fts") + //nolint:errcheck // Best effort cleanup + _, _ = db.fts.db.Exec("DELETE FROM backend_servers_fts") + } + + logger.Debug("Reset all collections and FTS tables") +} diff --git 
a/pkg/vmcp/optimizer/internal/db/db_test.go b/pkg/vmcp/optimizer/internal/db/db_test.go new file mode 100644 index 0000000000..197015a772 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/db_test.go @@ -0,0 +1,305 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestNewDB_CorruptedDatabase tests database recovery from corruption +func TestNewDB_CorruptedDatabase(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + config := &Config{ + PersistPath: dbPath, + } + + // Should recover from corruption + db, err := newChromemDB(config) + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.close() }() +} + +// TestNewDB_CorruptedDatabase_RecoveryFailure tests when recovery fails +func TestNewDB_CorruptedDatabase_RecoveryFailure(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + dbPath := filepath.Join(tmpDir, "corrupted-db") + + // Create a directory that looks like a corrupted database + err := os.MkdirAll(dbPath, 0755) + require.NoError(t, err) + + // Create a file that might cause issues + err = os.WriteFile(filepath.Join(dbPath, "some-file"), []byte("corrupted"), 0644) + require.NoError(t, err) + + // Make directory read-only to simulate recovery failure + // Note: This might not work on all systems, so we'll test the error path differently + // Instead, we'll test with an invalid path that can't be created + config := &Config{ + PersistPath: "/invalid/path/that/does/not/exist", + } + + _, err = 
newChromemDB(config) + // Should return error for invalid path + assert.Error(t, err) +} + +// TestDB_GetOrCreateCollection tests collection creation and retrieval +func TestDB_GetOrCreateCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() { _ = db.close() }() + + // Create a simple embedding function + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get or create collection + collection, err := db.getOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) + + // Get existing collection + collection2, err := db.getOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection2) + assert.Equal(t, collection, collection2) +} + +// TestDB_GetCollection tests collection retrieval +func TestDB_GetCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() { _ = db.close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Get non-existent collection should fail + _, err = db.getCollection("non-existent", embeddingFunc) + assert.Error(t, err) + + // Create collection first + _, err = db.getOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Now get it + collection, err := db.getCollection("test-collection", embeddingFunc) + require.NoError(t, err) + require.NotNil(t, collection) +} + +// TestDB_DeleteCollection tests collection deletion +func TestDB_DeleteCollection(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // 
In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() { _ = db.close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collection + _, err = db.getOrCreateCollection(ctx, "test-collection", embeddingFunc) + require.NoError(t, err) + + // Delete collection + db.deleteCollection("test-collection") + + // Verify it's deleted + _, err = db.getCollection("test-collection", embeddingFunc) + assert.Error(t, err) +} + +// TestDB_Reset tests database reset +func TestDB_Reset(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() { _ = db.close() }() + + embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { + return []float32{0.1, 0.2, 0.3}, nil + } + + // Create collections + _, err = db.getOrCreateCollection(ctx, BackendServerCollection, embeddingFunc) + require.NoError(t, err) + + _, err = db.getOrCreateCollection(ctx, BackendToolCollection, embeddingFunc) + require.NoError(t, err) + + // Reset database + db.reset() + + // Verify collections are deleted + _, err = db.getCollection(BackendServerCollection, embeddingFunc) + assert.Error(t, err) + + _, err = db.getCollection(BackendToolCollection, embeddingFunc) + assert.Error(t, err) +} + +// TestDB_GetChromemDB tests chromem DB accessor +func TestDB_GetChromemDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() { _ = db.close() }() + + chromemDB := db.getChromemDB() + require.NotNil(t, chromemDB) +} + +// TestDB_GetFTSDB tests FTS DB accessor +func TestDB_GetFTSDB(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + defer func() 
{ _ = db.close() }() + + ftsDB := db.getFTSDB() + require.NotNil(t, ftsDB) +} + +// TestDB_Close tests database closing +func TestDB_Close(t *testing.T) { + t.Parallel() + + config := &Config{ + PersistPath: "", // In-memory + } + + db, err := newChromemDB(config) + require.NoError(t, err) + + err = db.close() + require.NoError(t, err) + + // Multiple closes should be safe + err = db.close() + require.NoError(t, err) +} + +// TestNewDB_FTSDBPath tests FTS database path configuration +func TestNewDB_FTSDBPath(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "in-memory FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db"), + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + { + name: "persistent FTS with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db2"), + FTSDBPath: filepath.Join(tmpDir, "fts.db"), + }, + wantErr: false, + }, + { + name: "default FTS path with persistent chromem", + config: &Config{ + PersistPath: filepath.Join(tmpDir, "db3"), + // FTSDBPath not set, should default to {PersistPath}/fts.db + }, + wantErr: false, + }, + { + name: "in-memory FTS with in-memory chromem", + config: &Config{ + PersistPath: "", + FTSDBPath: ":memory:", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + db, err := newChromemDB(tt.config) + if tt.wantErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + require.NotNil(t, db) + defer func() { _ = db.close() }() + + // Verify FTS DB is accessible + ftsDB := db.getFTSDB() + require.NotNil(t, ftsDB) + } + }) + } +} diff --git a/pkg/vmcp/optimizer/internal/db/fts.go b/pkg/vmcp/optimizer/internal/db/fts.go new file mode 100644 index 0000000000..a325ab5e48 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/fts.go @@ -0,0 +1,360 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, 
Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "database/sql" + _ "embed" + "fmt" + "strings" + "sync" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +//go:embed schema_fts.sql +var schemaFTS string + +// FTSConfig holds FTS5 database configuration +type FTSConfig struct { + // DBPath is the path to the SQLite database file + // If empty, uses ":memory:" for in-memory database + DBPath string +} + +// FTSDatabase handles FTS5 (BM25) search operations +type FTSDatabase struct { + config *FTSConfig + db *sql.DB + mu sync.RWMutex +} + +// NewFTSDatabase creates a new FTS5 database for BM25 search +func NewFTSDatabase(config *FTSConfig) (*FTSDatabase, error) { + dbPath := config.DBPath + if dbPath == "" { + dbPath = ":memory:" + } + + // Open with modernc.org/sqlite (pure Go) + sqlDB, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, fmt.Errorf("failed to open FTS database: %w", err) + } + + // Set pragmas for performance + pragmas := []string{ + "PRAGMA journal_mode=WAL", + "PRAGMA synchronous=NORMAL", + "PRAGMA foreign_keys=ON", + "PRAGMA busy_timeout=5000", + } + + for _, pragma := range pragmas { + if _, err := sqlDB.Exec(pragma); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to set pragma: %w", err) + } + } + + ftsDB := &FTSDatabase{ + config: config, + db: sqlDB, + } + + // Initialize schema + if err := ftsDB.initializeSchema(); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("failed to initialize FTS schema: %w", err) + } + + logger.Infof("FTS5 database initialized successfully at: %s", dbPath) + return ftsDB, nil +} + +// initializeSchema creates the FTS5 tables and triggers +// +// Note: We execute the schema directly rather than using a migration framework +// because the FTS database is ephemeral (destroyed on shutdown, recreated on startup). 
+// Migrations are only needed when you need to preserve data across schema changes. +func (fts *FTSDatabase) initializeSchema() error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.Exec(schemaFTS) + if err != nil { + return fmt.Errorf("failed to execute schema: %w", err) + } + + logger.Debug("FTS5 schema initialized") + return nil +} + +// UpsertServer inserts or updates a server in the FTS database +func (fts *FTSDatabase) UpsertServer( + ctx context.Context, + server *models.BackendServer, +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + query := ` + INSERT INTO backend_servers_fts (id, name, description, server_group, last_updated, created_at) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(id) DO UPDATE SET + name = excluded.name, + description = excluded.description, + server_group = excluded.server_group, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + server.ID, + server.Name, + server.Description, + server.Group, + server.LastUpdated, + server.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert server in FTS: %w", err) + } + + logger.Debugf("Upserted server in FTS: %s", server.ID) + return nil +} + +// UpsertToolMeta inserts or updates a tool in the FTS database +func (fts *FTSDatabase) UpsertToolMeta( + ctx context.Context, + tool *models.BackendTool, + _ string, // serverName - unused, keeping for interface compatibility +) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Convert input schema to JSON string + var schemaStr *string + if len(tool.InputSchema) > 0 { + str := string(tool.InputSchema) + schemaStr = &str + } + + query := ` + INSERT INTO backend_tools_fts ( + id, mcpserver_id, tool_name, tool_description, + input_schema, token_count, last_updated, created_at + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) 
+ ON CONFLICT(id) DO UPDATE SET + mcpserver_id = excluded.mcpserver_id, + tool_name = excluded.tool_name, + tool_description = excluded.tool_description, + input_schema = excluded.input_schema, + token_count = excluded.token_count, + last_updated = excluded.last_updated + ` + + _, err := fts.db.ExecContext( + ctx, + query, + tool.ID, + tool.MCPServerID, + tool.ToolName, + tool.Description, + schemaStr, + tool.TokenCount, + tool.LastUpdated, + tool.CreatedAt, + ) + + if err != nil { + return fmt.Errorf("failed to upsert tool in FTS: %w", err) + } + + logger.Debugf("Upserted tool in FTS: %s", tool.ToolName) + return nil +} + +// DeleteServer removes a server and its tools from FTS database +func (fts *FTSDatabase) DeleteServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + // Foreign key cascade will delete related tools + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_servers_fts WHERE id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete server from FTS: %w", err) + } + + logger.Debugf("Deleted server from FTS: %s", serverID) + return nil +} + +// DeleteToolsByServer removes all tools for a server from FTS database +func (fts *FTSDatabase) DeleteToolsByServer(ctx context.Context, serverID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + result, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE mcpserver_id = ?", serverID) + if err != nil { + return fmt.Errorf("failed to delete tools from FTS: %w", err) + } + + count, _ := result.RowsAffected() + logger.Debugf("Deleted %d tools from FTS for server: %s", count, serverID) + return nil +} + +// DeleteTool removes a tool from FTS database +func (fts *FTSDatabase) DeleteTool(ctx context.Context, toolID string) error { + fts.mu.Lock() + defer fts.mu.Unlock() + + _, err := fts.db.ExecContext(ctx, "DELETE FROM backend_tools_fts WHERE id = ?", toolID) + if err != nil { + return fmt.Errorf("failed to delete tool from FTS: %w", 
err) + } + + logger.Debugf("Deleted tool from FTS: %s", toolID) + return nil +} + +// SearchBM25 performs BM25 full-text search on tools +func (fts *FTSDatabase) SearchBM25( + ctx context.Context, + query string, + limit int, + serverID *string, +) ([]*models.BackendToolWithMetadata, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + // Sanitize FTS5 query + sanitizedQuery := sanitizeFTS5Query(query) + if sanitizedQuery == "" { + return []*models.BackendToolWithMetadata{}, nil + } + + // Build query with optional server filter + sqlQuery := ` + SELECT + t.id, + t.mcpserver_id, + t.tool_name, + t.tool_description, + t.input_schema, + t.token_count, + t.last_updated, + t.created_at, + fts.rank + FROM backend_tool_fts_index fts + JOIN backend_tools_fts t ON fts.tool_id = t.id + WHERE backend_tool_fts_index MATCH ? + ` + + args := []interface{}{sanitizedQuery} + + if serverID != nil { + sqlQuery += " AND t.mcpserver_id = ?" + args = append(args, *serverID) + } + + sqlQuery += " ORDER BY rank LIMIT ?" + args = append(args, limit) + + rows, err := fts.db.QueryContext(ctx, sqlQuery, args...) 
+ if err != nil { + return nil, fmt.Errorf("failed to search tools: %w", err) + } + defer func() { _ = rows.Close() }() + + var results []*models.BackendToolWithMetadata + for rows.Next() { + var tool models.BackendTool + var schemaStr sql.NullString + var rank float32 + + err := rows.Scan( + &tool.ID, + &tool.MCPServerID, + &tool.ToolName, + &tool.Description, + &schemaStr, + &tool.TokenCount, + &tool.LastUpdated, + &tool.CreatedAt, + &rank, + ) + if err != nil { + logger.Warnf("Failed to scan tool row: %v", err) + continue + } + + if schemaStr.Valid { + tool.InputSchema = []byte(schemaStr.String) + } + + // Convert BM25 rank to similarity score (higher is better) + // FTS5 rank is negative, so we negate and normalize + similarity := float32(1.0 / (1.0 - float64(rank))) + + results = append(results, &models.BackendToolWithMetadata{ + BackendTool: tool, + Similarity: similarity, + }) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating tool rows: %w", err) + } + + logger.Debugf("BM25 search found %d tools for query: %s", len(results), query) + return results, nil +} + +// GetTotalToolTokens returns the sum of token_count across all tools +func (fts *FTSDatabase) GetTotalToolTokens(ctx context.Context) (int, error) { + fts.mu.RLock() + defer fts.mu.RUnlock() + + var totalTokens int + query := "SELECT COALESCE(SUM(token_count), 0) FROM backend_tools_fts" + + err := fts.db.QueryRowContext(ctx, query).Scan(&totalTokens) + if err != nil { + return 0, fmt.Errorf("failed to get total tool tokens: %w", err) + } + + return totalTokens, nil +} + +// Close closes the FTS database connection +func (fts *FTSDatabase) Close() error { + return fts.db.Close() +} + +// sanitizeFTS5Query escapes special characters in FTS5 queries +// FTS5 uses: " * ( ) AND OR NOT +func sanitizeFTS5Query(query string) string { + // Remove or escape special FTS5 characters + replacer := strings.NewReplacer( + `"`, `""`, // Escape quotes + `*`, ` `, // Remove wildcards + 
`(`, ` `, // Remove parentheses + `)`, ` `, + ) + + sanitized := replacer.Replace(query) + + // Remove multiple spaces + sanitized = strings.Join(strings.Fields(sanitized), " ") + + return strings.TrimSpace(sanitized) +} diff --git a/pkg/vmcp/optimizer/internal/db/fts_test_coverage.go b/pkg/vmcp/optimizer/internal/db/fts_test_coverage.go new file mode 100644 index 0000000000..ab358020ae --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/fts_test_coverage.go @@ -0,0 +1,162 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// stringPtr returns a pointer to the given string +func stringPtr(s string) *string { + return &s +} + +// TestFTSDatabase_GetTotalToolTokens tests token counting +func TestFTSDatabase_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Initially should be 0 + totalTokens, err := ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 0, totalTokens) + + // Add a tool + tool := &models.BackendTool{ + ID: "tool-1", + MCPServerID: "server-1", + ToolName: "test_tool", + Description: stringPtr("Test tool"), + TokenCount: 100, + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool, "TestServer") + require.NoError(t, err) + + // Should now have tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 100, totalTokens) + + // Add another tool + tool2 := &models.BackendTool{ + ID: "tool-2", + MCPServerID: "server-1", + ToolName: "test_tool2", + Description: stringPtr("Test tool 2"), + TokenCount: 50, 
+ CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + err = ftsDB.UpsertToolMeta(ctx, tool2, "TestServer") + require.NoError(t, err) + + // Should sum tokens + totalTokens, err = ftsDB.GetTotalToolTokens(ctx) + require.NoError(t, err) + assert.Equal(t, 150, totalTokens) +} + +// TestSanitizeFTS5Query tests query sanitization +func TestSanitizeFTS5Query(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "remove quotes", + input: `"test query"`, + expected: "test query", + }, + { + name: "remove wildcards", + input: "test*query", + expected: "test query", + }, + { + name: "remove parentheses", + input: "test(query)", + expected: "test query", + }, + { + name: "remove multiple spaces", + input: "test query", + expected: "test query", + }, + { + name: "trim whitespace", + input: " test query ", + expected: "test query", + }, + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "only special characters", + input: `"*()`, + expected: "", + }, + { + name: "mixed special characters", + input: `test"query*with(special)chars`, + expected: "test query with special chars", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := sanitizeFTS5Query(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} + +// TestFTSDatabase_SearchBM25_EmptyQuery tests empty query handling +func TestFTSDatabase_SearchBM25_EmptyQuery(t *testing.T) { + t.Parallel() + ctx := context.Background() + + config := &FTSConfig{ + DBPath: ":memory:", + } + + ftsDB, err := NewFTSDatabase(config) + require.NoError(t, err) + defer func() { _ = ftsDB.Close() }() + + // Empty query should return empty results + results, err := ftsDB.SearchBM25(ctx, "", 10, nil) + require.NoError(t, err) + assert.Empty(t, results) + + // Query with only special characters should return empty results + results, err = ftsDB.SearchBM25(ctx, `"*()`, 10, nil) + require.NoError(t, err) + 
assert.Empty(t, results) +} diff --git a/pkg/vmcp/optimizer/internal/db/hybrid.go b/pkg/vmcp/optimizer/internal/db/hybrid.go new file mode 100644 index 0000000000..82059dcb85 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/hybrid.go @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + "fmt" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// HybridSearchConfig configures hybrid search behavior +type HybridSearchConfig struct { + // SemanticRatio controls the mix of semantic vs BM25 results (0-100, representing percentage) + // Default: 70 (70% semantic, 30% BM25) + SemanticRatio int + + // Limit is the total number of results to return + Limit int + + // ServerID optionally filters results to a specific server + ServerID *string +} + +// DefaultHybridConfig returns sensible defaults for hybrid search +func DefaultHybridConfig() *HybridSearchConfig { + return &HybridSearchConfig{ + SemanticRatio: 70, + Limit: 10, + } +} + +// searchHybrid performs hybrid search combining semantic (chromem-go) and BM25 (FTS5) results +// This matches the Python mcp-optimizer's hybrid search implementation +func (ops *backendToolOps) searchHybrid( + ctx context.Context, + queryText string, + config *HybridSearchConfig, +) ([]*models.BackendToolWithMetadata, error) { + if config == nil { + config = DefaultHybridConfig() + } + + // Calculate limits for each search method + // Convert percentage to ratio (0-100 -> 0.0-1.0) + semanticRatioFloat := float64(config.SemanticRatio) / 100.0 + semanticLimit := max(1, int(float64(config.Limit)*semanticRatioFloat)) + bm25Limit := max(1, config.Limit-semanticLimit) + + logger.Debugf( + "Hybrid search: semantic_limit=%d, bm25_limit=%d, ratio=%d%%", + semanticLimit, bm25Limit, config.SemanticRatio, + ) + + // Execute both searches in parallel + type searchResult struct { + 
results []*models.BackendToolWithMetadata + err error + } + + semanticCh := make(chan searchResult, 1) + bm25Ch := make(chan searchResult, 1) + + // Semantic search + go func() { + results, err := ops.search(ctx, queryText, semanticLimit, config.ServerID) + semanticCh <- searchResult{results, err} + }() + + // BM25 search + go func() { + results, err := ops.db.fts.SearchBM25(ctx, queryText, bm25Limit, config.ServerID) + bm25Ch <- searchResult{results, err} + }() + + // Collect results + var semanticResults, bm25Results []*models.BackendToolWithMetadata + var errs []error + + // Wait for semantic results + semanticRes := <-semanticCh + if semanticRes.err != nil { + logger.Warnf("Semantic search failed: %v", semanticRes.err) + errs = append(errs, semanticRes.err) + } else { + semanticResults = semanticRes.results + } + + // Wait for BM25 results + bm25Res := <-bm25Ch + if bm25Res.err != nil { + logger.Warnf("BM25 search failed: %v", bm25Res.err) + errs = append(errs, bm25Res.err) + } else { + bm25Results = bm25Res.results + } + + // If both failed, return error + if len(errs) == 2 { + return nil, fmt.Errorf("both search methods failed: semantic=%v, bm25=%v", errs[0], errs[1]) + } + + // Combine and deduplicate results + combined := combineAndDeduplicateResults(semanticResults, bm25Results, config.Limit) + + logger.Infof( + "Hybrid search completed: semantic=%d, bm25=%d, combined=%d (requested=%d)", + len(semanticResults), len(bm25Results), len(combined), config.Limit, + ) + + return combined, nil +} + +// combineAndDeduplicateResults merges semantic and BM25 results, removing duplicates +// Keeps the result with the higher similarity score for duplicates +func combineAndDeduplicateResults( + semantic, bm25 []*models.BackendToolWithMetadata, + limit int, +) []*models.BackendToolWithMetadata { + // Use a map to deduplicate by tool ID + seen := make(map[string]*models.BackendToolWithMetadata) + + // Add semantic results first (they typically have higher quality) + for 
_, result := range semantic { + seen[result.ID] = result + } + + // Add BM25 results, only if not seen or if similarity is higher + for _, result := range bm25 { + if existing, exists := seen[result.ID]; exists { + // Keep the one with higher similarity + if result.Similarity > existing.Similarity { + seen[result.ID] = result + } + } else { + seen[result.ID] = result + } + } + + // Convert map to slice + combined := make([]*models.BackendToolWithMetadata, 0, len(seen)) + for _, result := range seen { + combined = append(combined, result) + } + + // Sort by similarity (descending) and limit + sortedResults := sortBySimilarity(combined) + if len(sortedResults) > limit { + sortedResults = sortedResults[:limit] + } + + return sortedResults +} + +// sortBySimilarity sorts results by similarity score in descending order +func sortBySimilarity(results []*models.BackendToolWithMetadata) []*models.BackendToolWithMetadata { + // Simple bubble sort (fine for small result sets) + sorted := make([]*models.BackendToolWithMetadata, len(results)) + copy(sorted, results) + + for i := 0; i < len(sorted); i++ { + for j := i + 1; j < len(sorted); j++ { + if sorted[j].Similarity > sorted[i].Similarity { + sorted[i], sorted[j] = sorted[j], sorted[i] + } + } + } + + return sorted +} diff --git a/pkg/vmcp/optimizer/internal/db/interface.go b/pkg/vmcp/optimizer/internal/db/interface.go new file mode 100644 index 0000000000..37e0c82884 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/interface.go @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package db + +import ( + "context" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" +) + +// Database is the main interface for optimizer database operations. +// It provides methods for managing backend servers and tools with hybrid search capabilities. 
+type Database interface { + // Server operations + CreateOrUpdateServer(ctx context.Context, server *models.BackendServer) error + DeleteServer(ctx context.Context, serverID string) error + + // Tool operations + CreateTool(ctx context.Context, tool *models.BackendTool, serverName string) error + DeleteToolsByServer(ctx context.Context, serverID string) error + SearchToolsHybrid(ctx context.Context, query string, config *HybridSearchConfig) ([]*models.BackendToolWithMetadata, error) + ListToolsByServer(ctx context.Context, serverID string) ([]*models.BackendTool, error) + + // Statistics + GetTotalToolTokens(ctx context.Context) (int, error) + + // Lifecycle + Reset() + Close() error +} diff --git a/pkg/vmcp/optimizer/internal/db/schema_fts.sql b/pkg/vmcp/optimizer/internal/db/schema_fts.sql new file mode 100644 index 0000000000..101dbea7d7 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/db/schema_fts.sql @@ -0,0 +1,120 @@ +-- FTS5 schema for BM25 full-text search +-- Complements chromem-go (which handles vector/semantic search) +-- +-- This schema only contains: +-- 1. Metadata tables for tool/server information +-- 2. 
FTS5 virtual tables for BM25 keyword search +-- +-- Note: chromem-go handles embeddings separately in memory/persistent storage + +-- Backend servers metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_servers_fts ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT, + server_group TEXT NOT NULL DEFAULT 'default', + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_backend_servers_fts_group ON backend_servers_fts(server_group); + +-- Backend tools metadata (for FTS queries and joining) +CREATE TABLE IF NOT EXISTS backend_tools_fts ( + id TEXT PRIMARY KEY, + mcpserver_id TEXT NOT NULL, + tool_name TEXT NOT NULL, + tool_description TEXT, + input_schema TEXT, -- JSON string + token_count INTEGER NOT NULL DEFAULT 0, + last_updated TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (mcpserver_id) REFERENCES backend_servers_fts(id) ON DELETE CASCADE +); + +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_server ON backend_tools_fts(mcpserver_id); +CREATE INDEX IF NOT EXISTS idx_backend_tools_fts_name ON backend_tools_fts(tool_name); + +-- FTS5 virtual table for backend tools +-- Uses Porter stemming for better keyword matching +-- Indexes: server name, tool name, and tool description +CREATE VIRTUAL TABLE IF NOT EXISTS backend_tool_fts_index +USING fts5( + tool_id UNINDEXED, + mcp_server_name, + tool_name, + tool_description, + tokenize='porter', + content='backend_tools_fts', + content_rowid='rowid' +); + +-- Triggers to keep FTS5 index in sync with backend_tools_fts table +CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ai AFTER INSERT ON backend_tools_fts BEGIN + INSERT INTO backend_tool_fts_index( + rowid, + tool_id, + mcp_server_name, + tool_name, + tool_description + ) + SELECT + rowid, + new.id, + (SELECT name FROM backend_servers_fts WHERE id = 
new.mcpserver_id),
+        new.tool_name,
+        COALESCE(new.tool_description, '')
+    FROM backend_tools_fts
+    WHERE id = new.id;
+END;
+
+-- FTS5 'delete' commands MUST carry the exact values that were originally
+-- indexed for the row; passing NULLs leaves stale terms in the index and makes
+-- subsequent query results undefined (see SQLite FTS5 docs, "The 'delete'
+-- Command" / external content tables).
+-- NOTE(review): if the parent backend_servers_fts row is removed before its
+-- tools (ON DELETE CASCADE), the server-name subquery yields NULL and the
+-- originally indexed name cannot be reproduced -- confirm deletion ordering,
+-- or denormalize the server name into backend_tools_fts.
+CREATE TRIGGER IF NOT EXISTS backend_tools_fts_ad AFTER DELETE ON backend_tools_fts BEGIN
+    INSERT INTO backend_tool_fts_index(
+        backend_tool_fts_index,
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        'delete',
+        old.rowid,
+        old.id,
+        (SELECT name FROM backend_servers_fts WHERE id = old.mcpserver_id),
+        old.tool_name,
+        COALESCE(old.tool_description, '')
+    );
+END;
+
+-- Update = delete-with-old-values followed by re-insert with new values.
+CREATE TRIGGER IF NOT EXISTS backend_tools_fts_au AFTER UPDATE ON backend_tools_fts BEGIN
+    INSERT INTO backend_tool_fts_index(
+        backend_tool_fts_index,
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    ) VALUES (
+        'delete',
+        old.rowid,
+        old.id,
+        (SELECT name FROM backend_servers_fts WHERE id = old.mcpserver_id),
+        old.tool_name,
+        COALESCE(old.tool_description, '')
+    );
+    INSERT INTO backend_tool_fts_index(
+        rowid,
+        tool_id,
+        mcp_server_name,
+        tool_name,
+        tool_description
+    )
+    SELECT
+        rowid,
+        new.id,
+        (SELECT name FROM backend_servers_fts WHERE id = new.mcpserver_id),
+        new.tool_name,
+        COALESCE(new.tool_description, '')
+    FROM backend_tools_fts
+    WHERE id = new.id;
+END;
diff --git a/pkg/vmcp/optimizer/internal/db/sqlite_fts.go b/pkg/vmcp/optimizer/internal/db/sqlite_fts.go
new file mode 100644
index 0000000000..23ae5bcdfb
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/db/sqlite_fts.go
@@ -0,0 +1,11 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+// Package db provides database operations for the optimizer.
+// This file handles FTS5 (Full-Text Search) using modernc.org/sqlite (pure Go).
+package db
+
+import (
+	// Pure Go SQLite driver with FTS5 support
+	_ "modernc.org/sqlite"
+)
diff --git a/pkg/vmcp/optimizer/internal/embeddings/cache.go b/pkg/vmcp/optimizer/internal/embeddings/cache.go
new file mode 100644
index 0000000000..68f6bbe74b
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/embeddings/cache.go
@@ -0,0 +1,104 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0 + +// Package embeddings provides caching for embedding vectors. +package embeddings + +import ( + "container/list" + "sync" +) + +// cache implements an LRU cache for embeddings +type cache struct { + maxSize int + mu sync.RWMutex + items map[string]*list.Element + lru *list.List + hits int64 + misses int64 +} + +type cacheEntry struct { + key string + value []float32 +} + +// newCache creates a new LRU cache +func newCache(maxSize int) *cache { + return &cache{ + maxSize: maxSize, + items: make(map[string]*list.Element), + lru: list.New(), + } +} + +// Get retrieves an embedding from the cache +func (c *cache) Get(key string) []float32 { + c.mu.Lock() + defer c.mu.Unlock() + + elem, ok := c.items[key] + if !ok { + c.misses++ + return nil + } + + c.hits++ + c.lru.MoveToFront(elem) + return elem.Value.(*cacheEntry).value +} + +// Put stores an embedding in the cache +func (c *cache) Put(key string, value []float32) { + c.mu.Lock() + defer c.mu.Unlock() + + // Check if key already exists + if elem, ok := c.items[key]; ok { + c.lru.MoveToFront(elem) + elem.Value.(*cacheEntry).value = value + return + } + + // Add new entry + entry := &cacheEntry{ + key: key, + value: value, + } + elem := c.lru.PushFront(entry) + c.items[key] = elem + + // Evict if necessary + if c.lru.Len() > c.maxSize { + c.evict() + } +} + +// evict removes the least recently used item +func (c *cache) evict() { + elem := c.lru.Back() + if elem != nil { + c.lru.Remove(elem) + entry := elem.Value.(*cacheEntry) + delete(c.items, entry.key) + } +} + +// Size returns the current cache size +func (c *cache) Size() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.lru.Len() +} + +// Clear clears the cache +func (c *cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + c.items = make(map[string]*list.Element) + c.lru = list.New() + c.hits = 0 + c.misses = 0 +} diff --git a/pkg/vmcp/optimizer/internal/embeddings/cache_test.go 
b/pkg/vmcp/optimizer/internal/embeddings/cache_test.go new file mode 100644 index 0000000000..9b16346056 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/embeddings/cache_test.go @@ -0,0 +1,172 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "testing" +) + +func TestCache_GetPut(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Test cache miss + result := c.Get("key1") + if result != nil { + t.Error("Expected cache miss for non-existent key") + } + if c.misses != 1 { + t.Errorf("Expected 1 miss, got %d", c.misses) + } + + // Test cache put and hit + embedding := []float32{1.0, 2.0, 3.0} + c.Put("key1", embedding) + + result = c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + if c.hits != 1 { + t.Errorf("Expected 1 hit, got %d", c.hits) + } + + // Verify embedding values + if len(result) != len(embedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(embedding)) + } + for i := range embedding { + if result[i] != embedding[i] { + t.Errorf("Embedding value mismatch at index %d: got %f, want %f", i, result[i], embedding[i]) + } + } +} + +func TestCache_LRUEviction(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items (fills cache) + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2, got %d", c.Size()) + } + + // Add third item (should evict key1) + c.Put("key3", []float32{3.0}) + + if c.Size() != 2 { + t.Errorf("Expected cache size 2 after eviction, got %d", c.Size()) + } + + // key1 should be evicted (oldest) + if result := c.Get("key1"); result != nil { + t.Error("key1 should have been evicted") + } + + // key2 and key3 should still exist + if result := c.Get("key2"); result == nil { + t.Error("key2 should still exist") + } + if result := c.Get("key3"); result == nil { + t.Error("key3 should still exist") + } +} + +func 
TestCache_MoveToFrontOnAccess(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add two items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + + // Access key1 (moves it to front) + c.Get("key1") + + // Add third item (should evict key2, not key1) + c.Put("key3", []float32{3.0}) + + // key1 should still exist (was accessed recently) + if result := c.Get("key1"); result == nil { + t.Error("key1 should still exist (was accessed recently)") + } + + // key2 should be evicted (was oldest) + if result := c.Get("key2"); result != nil { + t.Error("key2 should have been evicted") + } + + // key3 should exist + if result := c.Get("key3"); result == nil { + t.Error("key3 should exist") + } +} + +func TestCache_UpdateExistingKey(t *testing.T) { + t.Parallel() + c := newCache(2) + + // Add initial value + c.Put("key1", []float32{1.0}) + + // Update with new value + newEmbedding := []float32{2.0, 3.0} + c.Put("key1", newEmbedding) + + // Should get updated value + result := c.Get("key1") + if result == nil { + t.Fatal("Expected cache hit for existing key") + } + + if len(result) != len(newEmbedding) { + t.Errorf("Embedding length mismatch: got %d, want %d", len(result), len(newEmbedding)) + } + + // Cache size should still be 1 + if c.Size() != 1 { + t.Errorf("Expected cache size 1, got %d", c.Size()) + } +} + +func TestCache_Clear(t *testing.T) { + t.Parallel() + c := newCache(10) + + // Add some items + c.Put("key1", []float32{1.0}) + c.Put("key2", []float32{2.0}) + c.Put("key3", []float32{3.0}) + + // Access some items to generate stats + c.Get("key1") + c.Get("missing") + + if c.Size() != 3 { + t.Errorf("Expected cache size 3, got %d", c.Size()) + } + + // Clear cache + c.Clear() + + if c.Size() != 0 { + t.Errorf("Expected cache size 0 after clear, got %d", c.Size()) + } + + // Stats should be reset + if c.hits != 0 { + t.Errorf("Expected 0 hits after clear, got %d", c.hits) + } + if c.misses != 0 { + t.Errorf("Expected 0 misses after clear, got %d", 
c.misses) + } + + // Items should be gone + if result := c.Get("key1"); result != nil { + t.Error("key1 should be gone after clear") + } +} diff --git a/pkg/vmcp/optimizer/internal/embeddings/manager.go b/pkg/vmcp/optimizer/internal/embeddings/manager.go new file mode 100644 index 0000000000..4f29729e3b --- /dev/null +++ b/pkg/vmcp/optimizer/internal/embeddings/manager.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "fmt" + "strings" + "sync" + + "github.com/stacklok/toolhive/pkg/logger" +) + +const ( + // DefaultModelAllMiniLM is the default Ollama model name + DefaultModelAllMiniLM = "all-minilm" + // BackendTypeOllama is the Ollama backend type + BackendTypeOllama = "ollama" +) + +// Config holds configuration for the embedding manager +type Config struct { + // BackendType specifies which backend to use: + // - "ollama": Ollama native API (default) + // - "vllm": vLLM OpenAI-compatible API + // - "unified": Generic OpenAI-compatible API (works with both) + // - "openai": OpenAI-compatible API + BackendType string + + // BaseURL is the base URL for the embedding service + // - Ollama: http://127.0.0.1:11434 (or http://localhost:11434, will be normalized to 127.0.0.1) + // - vLLM: http://localhost:8000 + BaseURL string + + // Model is the model name to use + // - Ollama: "all-minilm" (default), "nomic-embed-text" + // - vLLM: "sentence-transformers/all-MiniLM-L6-v2", "intfloat/e5-mistral-7b-instruct" + Model string + + // Dimension is the embedding dimension (default 384 for all-MiniLM-L6-v2) + Dimension int + + // EnableCache enables caching of embeddings + EnableCache bool + + // MaxCacheSize is the maximum number of embeddings to cache (default 1000) + MaxCacheSize int +} + +// Backend interface for different embedding implementations +type Backend interface { + Embed(text string) ([]float32, error) + EmbedBatch(texts []string) ([][]float32, error) + 
Dimension() int + Close() error +} + +// Manager manages embedding generation using pluggable backends +// Default backend is all-MiniLM-L6-v2 (same model as codegate) +type Manager struct { + config *Config + backend Backend + cache *cache + mu sync.RWMutex +} + +// NewManager creates a new embedding manager +func NewManager(config *Config) (*Manager, error) { + if config.Dimension == 0 { + config.Dimension = 384 // Default dimension for all-MiniLM-L6-v2 + } + + if config.MaxCacheSize == 0 { + config.MaxCacheSize = 1000 + } + + // Default to Ollama + if config.BackendType == "" { + config.BackendType = BackendTypeOllama + } + + // Initialize backend based on configuration + var backend Backend + var err error + + switch config.BackendType { + case BackendTypeOllama: + // Use Ollama native API (requires ollama serve) + baseURL := config.BaseURL + if baseURL == "" { + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = strings.ReplaceAll(baseURL, "localhost", "127.0.0.1") + } + model := config.Model + if model == "" { + model = DefaultModelAllMiniLM // Default: all-MiniLM-L6-v2 + } + // Update dimension if not set and using default model + if config.Dimension == 0 && model == DefaultModelAllMiniLM { + config.Dimension = 384 + } + backend, err = NewOllamaBackend(baseURL, model) + if err != nil { + return nil, fmt.Errorf( + "failed to initialize Ollama backend: %w (ensure 'ollama serve' is running and 'ollama pull %s' has been executed)", + err, DefaultModelAllMiniLM) + } + + case "vllm", "unified", "openai": + // Use OpenAI-compatible API + // vLLM is recommended for production Kubernetes deployments (GPU-accelerated, high-throughput) + // Also supports: Ollama v1 API, OpenAI, or any OpenAI-compatible service + if config.BaseURL == "" { + return nil, fmt.Errorf("BaseURL is required for %s backend", config.BackendType) + } + if config.Model == "" { + return nil, fmt.Errorf("model is required 
for %s backend", config.BackendType) + } + backend, err = NewOpenAICompatibleBackend(config.BaseURL, config.Model, config.Dimension) + if err != nil { + return nil, fmt.Errorf("failed to initialize %s backend: %w", config.BackendType, err) + } + + default: + return nil, fmt.Errorf("unknown backend type: %s (supported: ollama, vllm, unified, openai)", config.BackendType) + } + + m := &Manager{ + config: config, + backend: backend, + } + + if config.EnableCache { + m.cache = newCache(config.MaxCacheSize) + } + + return m, nil +} + +// GenerateEmbedding generates embeddings for the given texts +// Returns a 2D slice where each row is an embedding for the corresponding text +// Uses all-MiniLM-L6-v2 by default (same model as codegate) +func (m *Manager) GenerateEmbedding(texts []string) ([][]float32, error) { + if len(texts) == 0 { + return nil, fmt.Errorf("no texts provided") + } + + // Check cache for single text requests + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + if cached := m.cache.Get(texts[0]); cached != nil { + logger.Debugf("Cache hit for embedding") + return [][]float32{cached}, nil + } + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Use backend to generate embeddings + embeddings, err := m.backend.EmbedBatch(texts) + if err != nil { + return nil, fmt.Errorf("failed to generate embeddings: %w", err) + } + + // Cache single embeddings + if len(texts) == 1 && m.config.EnableCache && m.cache != nil { + m.cache.Put(texts[0], embeddings[0]) + } + + logger.Debugf("Generated %d embeddings (dimension: %d)", len(embeddings), m.backend.Dimension()) + return embeddings, nil +} + +// GetCacheStats returns cache statistics +func (m *Manager) GetCacheStats() map[string]interface{} { + if !m.config.EnableCache || m.cache == nil { + return map[string]interface{}{ + "enabled": false, + } + } + + return map[string]interface{}{ + "enabled": true, + "hits": m.cache.hits, + "misses": m.cache.misses, + "size": m.cache.Size(), + "maxsize": 
m.config.MaxCacheSize,
+	}
+}
+
+// ClearCache clears the embedding cache
+func (m *Manager) ClearCache() {
+	if m.config.EnableCache && m.cache != nil {
+		m.cache.Clear()
+		logger.Info("Embedding cache cleared")
+	}
+}
+
+// Close releases resources
+func (m *Manager) Close() error {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	if m.backend != nil {
+		return m.backend.Close()
+	}
+
+	return nil
+}
+
+// Dimension returns the embedding dimension
+func (m *Manager) Dimension() int {
+	if m.backend != nil {
+		return m.backend.Dimension()
+	}
+	return m.config.Dimension
+}
diff --git a/pkg/vmcp/optimizer/internal/embeddings/manager_coverage_test.go b/pkg/vmcp/optimizer/internal/embeddings/manager_coverage_test.go
new file mode 100644
index 0000000000..529d65ec4c
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/embeddings/manager_coverage_test.go
@@ -0,0 +1,158 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// TestManager_GetCacheStats tests cache statistics
+func TestManager_GetCacheStats(t *testing.T) {
+	t.Parallel()
+
+	config := &Config{
+		BackendType:  "ollama",
+		BaseURL:      "http://localhost:11434",
+		Model:        "all-minilm",
+		Dimension:    384,
+		EnableCache:  true,
+		MaxCacheSize: 100,
+	}
+
+	manager, err := NewManager(config)
+	if err != nil {
+		t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.True(t, stats["enabled"].(bool)) + assert.Contains(t, stats, "hits") + assert.Contains(t, stats, "misses") + assert.Contains(t, stats, "size") + assert.Contains(t, stats, "maxsize") +} + +// TestManager_GetCacheStats_Disabled tests cache statistics when cache is disabled +func TestManager_GetCacheStats_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + stats := manager.GetCacheStats() + require.NotNil(t, stats) + assert.False(t, stats["enabled"].(bool)) +} + +// TestManager_ClearCache tests cache clearing +func TestManager_ClearCache(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic + manager.ClearCache() + + // Multiple clears should be safe + manager.ClearCache() +} + +// TestManager_ClearCache_Disabled tests cache clearing when cache is disabled +func TestManager_ClearCache_Disabled(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + // Clear cache should not panic even when disabled + manager.ClearCache() +} + +// TestManager_Dimension tests dimension accessor +func TestManager_Dimension(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} + +// TestManager_Dimension_Default tests default dimension +func TestManager_Dimension_Default(t *testing.T) { + t.Parallel() + + config := &Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + // Dimension not set, should default to 384 + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + defer func() { _ = manager.Close() }() + + dimension := manager.Dimension() + assert.Equal(t, 384, dimension) +} diff --git a/pkg/vmcp/optimizer/internal/embeddings/ollama.go b/pkg/vmcp/optimizer/internal/embeddings/ollama.go new file mode 100644 index 0000000000..6cb6e1f8a2 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/embeddings/ollama.go @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/stacklok/toolhive/pkg/logger" +) + +// OllamaBackend implements the Backend interface using Ollama +// This provides local embeddings without remote API calls +// Ollama must be running locally (ollama serve) +type OllamaBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type ollamaEmbedRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` +} + +type ollamaEmbedResponse struct { + Embedding []float64 `json:"embedding"` +} + +// normalizeLocalhostURL converts localhost to 127.0.0.1 to avoid IPv6 resolution issues +func normalizeLocalhostURL(url string) string { + // Replace localhost with 127.0.0.1 to ensure IPv4 connection + // This prevents connection refused errors when Ollama only listens on IPv4 + return strings.ReplaceAll(url, "localhost", "127.0.0.1") +} + +// NewOllamaBackend creates a new Ollama backend +// Requires Ollama to be running locally: ollama serve +// Default model: all-minilm (all-MiniLM-L6-v2, 384 dimensions) +func NewOllamaBackend(baseURL, model string) (*OllamaBackend, error) { + if baseURL == "" { + baseURL = "http://127.0.0.1:11434" + } else { + // Normalize localhost to 127.0.0.1 to avoid IPv6 resolution issues + baseURL = normalizeLocalhostURL(baseURL) + } + if model == "" { + model = "all-minilm" // Default embedding model (all-MiniLM-L6-v2) + } + + logger.Infof("Initializing Ollama backend (model: %s, url: %s)", model, baseURL) + + // Determine dimension based on model + dimension := 384 // Default for all-minilm + if model == "nomic-embed-text" { + dimension = 768 + } + + backend := &OllamaBackend{ + baseURL: baseURL, + model: model, + dimension: dimension, + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to 
Ollama at %s: %w (is 'ollama serve' running?)", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to Ollama") + return backend, nil +} + +// Embed generates an embedding for a single text +func (o *OllamaBackend) Embed(text string) ([]float32, error) { + reqBody := ollamaEmbedRequest{ + Model: o.model, + Prompt: text, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + resp, err := o.client.Post( + o.baseURL+"/api/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call Ollama API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp ollamaEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + // Convert []float64 to []float32 + embedding := make([]float32, len(embedResp.Embedding)) + for i, v := range embedResp.Embedding { + embedding[i] = float32(v) + } + + return embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OllamaBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OllamaBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OllamaBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/pkg/vmcp/optimizer/internal/embeddings/ollama_test.go 
b/pkg/vmcp/optimizer/internal/embeddings/ollama_test.go new file mode 100644 index 0000000000..16d7793e85 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/embeddings/ollama_test.go @@ -0,0 +1,69 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "testing" +) + +func TestOllamaBackend_ConnectionFailure(t *testing.T) { + t.Parallel() + // This test verifies that Ollama backend handles connection failures gracefully + + // Test that NewOllamaBackend handles connection failure gracefully + _, err := NewOllamaBackend("http://localhost:99999", "all-minilm") + if err == nil { + t.Error("Expected error when connecting to invalid Ollama URL") + } +} + +func TestManagerWithOllama(t *testing.T) { + t.Parallel() + // Test that Manager works with Ollama when available + config := &Config{ + BackendType: BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: DefaultModelAllMiniLM, + Dimension: 768, + EnableCache: true, + MaxCacheSize: 100, + } + + manager, err := NewManager(config) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + defer manager.Close() + + // Test single embedding + embeddings, err := manager.GenerateEmbedding([]string{"test text"}) + if err != nil { + // Model might not be pulled - skip gracefully + t.Skipf("Skipping test: Failed to generate embedding. Error: %v. 
Run 'ollama pull all-minilm'", err)
+		return
+	}
+
+	if len(embeddings) != 1 {
+		t.Errorf("Expected 1 embedding, got %d", len(embeddings))
+	}
+
+	// Ollama all-minilm uses 384 dimensions
+	if len(embeddings[0]) != 384 {
+		t.Errorf("Expected dimension 384, got %d", len(embeddings[0]))
+	}
+
+	// Test batch embeddings
+	texts := []string{"text 1", "text 2", "text 3"}
+	embeddings, err = manager.GenerateEmbedding(texts)
+	if err != nil {
+		// Model might not be pulled - skip gracefully
+		t.Skipf("Skipping test: Failed to generate batch embeddings. Error: %v. Run 'ollama pull all-minilm'", err)
+		return
+	}
+
+	if len(embeddings) != 3 {
+		t.Errorf("Expected 3 embeddings, got %d", len(embeddings))
+	}
+}
diff --git a/pkg/vmcp/optimizer/internal/embeddings/openai_compatible.go b/pkg/vmcp/optimizer/internal/embeddings/openai_compatible.go
new file mode 100644
index 0000000000..c98adba54a
--- /dev/null
+++ b/pkg/vmcp/optimizer/internal/embeddings/openai_compatible.go
@@ -0,0 +1,152 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package embeddings
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+
+	"github.com/stacklok/toolhive/pkg/logger"
+)
+
+// OpenAICompatibleBackend implements the Backend interface for OpenAI-compatible APIs.
+//
+// Supported Services:
+//   - vLLM: Recommended for production Kubernetes deployments
+//     - High-throughput GPU-accelerated inference
+//     - PagedAttention for efficient GPU memory utilization
+//     - Superior scalability for multi-user environments
+//   - Ollama: Good for local development (via /v1/embeddings endpoint)
+//   - OpenAI: For cloud-based embeddings
+//   - Any OpenAI-compatible embedding service
+//
+// For production deployments, vLLM is strongly recommended due to its performance
+// characteristics and Kubernetes-native design. 
+type OpenAICompatibleBackend struct { + baseURL string + model string + dimension int + client *http.Client +} + +type openaiEmbedRequest struct { + Model string `json:"model"` + Input string `json:"input"` // OpenAI standard uses "input" +} + +type openaiEmbedResponse struct { + Object string `json:"object"` + Data []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` +} + +// NewOpenAICompatibleBackend creates a new OpenAI-compatible backend. +// +// Examples: +// - vLLM: NewOpenAICompatibleBackend("http://vllm-service:8000", "sentence-transformers/all-MiniLM-L6-v2", 384) +// - Ollama: NewOpenAICompatibleBackend("http://localhost:11434", "nomic-embed-text", 768) +// - OpenAI: NewOpenAICompatibleBackend("https://api.openai.com", "text-embedding-3-small", 1536) +func NewOpenAICompatibleBackend(baseURL, model string, dimension int) (*OpenAICompatibleBackend, error) { + if baseURL == "" { + return nil, fmt.Errorf("baseURL is required for OpenAI-compatible backend") + } + if model == "" { + return nil, fmt.Errorf("model is required for OpenAI-compatible backend") + } + if dimension == 0 { + dimension = 384 // Default dimension + } + + logger.Infof("Initializing OpenAI-compatible backend (model: %s, url: %s)", model, baseURL) + + backend := &OpenAICompatibleBackend{ + baseURL: baseURL, + model: model, + dimension: dimension, + client: &http.Client{}, + } + + // Test connection + resp, err := backend.client.Get(baseURL) + if err != nil { + return nil, fmt.Errorf("failed to connect to %s: %w", baseURL, err) + } + _ = resp.Body.Close() + + logger.Info("Successfully connected to OpenAI-compatible service") + return backend, nil +} + +// Embed generates an embedding for a single text using OpenAI-compatible API +func (o *OpenAICompatibleBackend) Embed(text string) ([]float32, error) { + reqBody := openaiEmbedRequest{ + Model: o.model, + Input: text, + } + + jsonData, 
err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // Use standard OpenAI v1 endpoint + resp, err := o.client.Post( + o.baseURL+"/v1/embeddings", + "application/json", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return nil, fmt.Errorf("failed to call embeddings API: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + var embedResp openaiEmbedResponse + if err := json.NewDecoder(resp.Body).Decode(&embedResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(embedResp.Data) == 0 { + return nil, fmt.Errorf("no embeddings in response") + } + + return embedResp.Data[0].Embedding, nil +} + +// EmbedBatch generates embeddings for multiple texts +func (o *OpenAICompatibleBackend) EmbedBatch(texts []string) ([][]float32, error) { + embeddings := make([][]float32, len(texts)) + + for i, text := range texts { + emb, err := o.Embed(text) + if err != nil { + return nil, fmt.Errorf("failed to embed text %d: %w", i, err) + } + embeddings[i] = emb + } + + return embeddings, nil +} + +// Dimension returns the embedding dimension +func (o *OpenAICompatibleBackend) Dimension() int { + return o.dimension +} + +// Close releases any resources +func (*OpenAICompatibleBackend) Close() error { + // HTTP client doesn't need explicit cleanup + return nil +} diff --git a/pkg/vmcp/optimizer/internal/embeddings/openai_compatible_test.go b/pkg/vmcp/optimizer/internal/embeddings/openai_compatible_test.go new file mode 100644 index 0000000000..f9a686e953 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/embeddings/openai_compatible_test.go @@ -0,0 +1,226 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package embeddings + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +const testEmbeddingsEndpoint = "/v1/embeddings" + +func TestOpenAICompatibleBackend(t *testing.T) { + t.Parallel() + // Create a test server that mimics OpenAI-compatible API + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + var req openaiEmbedRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("Failed to decode request: %v", err) + } + + // Return a mock embedding response + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + Model: req.Model, + } + + // Fill with test data + for i := range resp.Data[0].Embedding { + resp.Data[0].Embedding[i] = float32(i) / 384.0 + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + + // Health check endpoint + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test backend creation + backend, err := NewOpenAICompatibleBackend(server.URL, "test-model", 384) + if err != nil { + t.Fatalf("Failed to create backend: %v", err) + } + defer backend.Close() + + // Test embedding generation + embedding, err := backend.Embed("test text") + if err != nil { + t.Fatalf("Failed to generate embedding: %v", err) + } + + if len(embedding) != 384 { + t.Errorf("Expected embedding dimension 384, got %d", len(embedding)) + } + + // Test batch embedding + texts := []string{"text1", "text2", "text3"} + embeddings, err := backend.EmbedBatch(texts) + if err != nil { + t.Fatalf("Failed to generate batch embeddings: %v", err) + } + + if len(embeddings) != len(texts) { + t.Errorf("Expected %d embeddings, got %d", len(texts), 
len(embeddings)) + } +} + +func TestOpenAICompatibleBackendErrors(t *testing.T) { + t.Parallel() + // Test missing baseURL + _, err := NewOpenAICompatibleBackend("", "model", 384) + if err == nil { + t.Error("Expected error for missing baseURL") + } + + // Test missing model + _, err = NewOpenAICompatibleBackend("http://localhost:8000", "", 384) + if err == nil { + t.Error("Expected error for missing model") + } +} + +func TestManagerWithVLLM(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 384), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with vLLM backend + config := &Config{ + BackendType: "vllm", + BaseURL: server.URL, + Model: "sentence-transformers/all-MiniLM-L6-v2", + Dimension: 384, + EnableCache: true, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } + if len(embeddings[0]) != 384 { + t.Errorf("Expected dimension 384, got %d", len(embeddings[0])) + } +} + +func TestManagerWithUnified(t *testing.T) { + t.Parallel() + // Create a test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == testEmbeddingsEndpoint { + 
resp := openaiEmbedResponse{ + Object: "list", + Data: []struct { + Object string `json:"object"` + Embedding []float32 `json:"embedding"` + Index int `json:"index"` + }{ + { + Object: "embedding", + Embedding: make([]float32, 768), + Index: 0, + }, + }, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Test manager with unified backend + config := &Config{ + BackendType: "unified", + BaseURL: server.URL, + Model: "nomic-embed-text", + Dimension: 768, + EnableCache: false, + } + + manager, err := NewManager(config) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + defer manager.Close() + + // Test embedding generation + embeddings, err := manager.GenerateEmbedding([]string{"test"}) + if err != nil { + t.Fatalf("Failed to generate embeddings: %v", err) + } + + if len(embeddings) != 1 { + t.Errorf("Expected 1 embedding, got %d", len(embeddings)) + } +} + +func TestManagerFallbackBehavior(t *testing.T) { + t.Parallel() + // Test that invalid vLLM backend fails gracefully during initialization + // (No fallback behavior is currently implemented) + config := &Config{ + BackendType: "vllm", + BaseURL: "http://invalid-host-that-does-not-exist:9999", + Model: "test-model", + Dimension: 384, + } + + _, err := NewManager(config) + if err == nil { + t.Error("Expected error when creating manager with invalid backend URL") + } + // Test passes if error is returned (no fallback behavior) +} diff --git a/pkg/vmcp/optimizer/internal/ingestion/errors.go b/pkg/vmcp/optimizer/internal/ingestion/errors.go new file mode 100644 index 0000000000..93e8eab31c --- /dev/null +++ b/pkg/vmcp/optimizer/internal/ingestion/errors.go @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package ingestion provides services for ingesting MCP tools into the database. 
+package ingestion + +import "errors" + +var ( + // ErrIngestionFailed is returned when ingestion fails + ErrIngestionFailed = errors.New("ingestion failed") + + // ErrBackendRetrievalFailed is returned when backend retrieval fails + ErrBackendRetrievalFailed = errors.New("backend retrieval failed") + + // ErrToolHiveUnavailable is returned when ToolHive is unavailable + ErrToolHiveUnavailable = errors.New("ToolHive unavailable") + + // ErrBackendStatusNil is returned when backend status is nil + ErrBackendStatusNil = errors.New("backend status cannot be nil") + + // ErrInvalidRuntimeMode is returned for invalid runtime mode + ErrInvalidRuntimeMode = errors.New("invalid runtime mode: must be 'docker' or 'k8s'") +) diff --git a/pkg/vmcp/optimizer/internal/ingestion/service.go b/pkg/vmcp/optimizer/internal/ingestion/service.go new file mode 100644 index 0000000000..5801758b94 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/ingestion/service.go @@ -0,0 +1,345 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" + + "github.com/stacklok/toolhive/pkg/logger" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/db" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/tokens" +) + +// Config holds configuration for the ingestion service +type Config struct { + // Database configuration + DBConfig *db.Config + + // Embedding configuration (flattened from embeddings.Config) + EmbeddingBackend string + EmbeddingURL string + EmbeddingModel string + EmbeddingDimension int + + // MCP timeout in seconds + MCPTimeout int + + // Workloads to skip during ingestion + SkippedWorkloads []string + + // Runtime mode: "docker" or "k8s" + RuntimeMode string + + // Kubernetes configuration (used when RuntimeMode is "k8s") + K8sAPIServerURL string + K8sNamespace string + K8sAllNamespaces bool +} + +// Service handles ingestion of MCP backends and their tools +type Service struct { + config *Config + database db.Database + embeddingManager *embeddings.Manager + tokenCounter *tokens.Counter + tracer trace.Tracer + + // Embedding time tracking + embeddingTimeMu sync.Mutex + totalEmbeddingTime time.Duration +} + +// NewService creates a new ingestion service +func NewService(config *Config) (*Service, error) { + // Set defaults + if config.MCPTimeout == 0 { + config.MCPTimeout = 30 + } + if len(config.SkippedWorkloads) == 0 { + config.SkippedWorkloads = []string{"inspector", "mcp-optimizer"} + } + + // Construct embeddings.Config from individual fields + embeddingConfig := &embeddings.Config{ + BackendType: 
config.EmbeddingBackend, + BaseURL: config.EmbeddingURL, + Model: config.EmbeddingModel, + Dimension: config.EmbeddingDimension, + } + + // Initialize embedding manager first (needed for database) + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize embedding manager: %w", err) + } + + // Initialize token counter + tokenCounter := tokens.NewCounter() + + // Initialize tracer + tracer := otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/ingestion") + + svc := &Service{ + config: config, + embeddingManager: embeddingManager, + tokenCounter: tokenCounter, + tracer: tracer, + totalEmbeddingTime: 0, + } + + // Create embedding function for database with tracing + embeddingFunc := func(ctx context.Context, text string) ([]float32, error) { + // Create a span for embedding calculation + _, span := svc.tracer.Start(ctx, "optimizer.ingestion.calculate_embedding", + trace.WithAttributes( + attribute.String("operation", "embedding_calculation"), + )) + defer span.End() + + start := time.Now() + + // Our manager takes a slice, so wrap the single text + embeddingsResult, err := embeddingManager.GenerateEmbedding([]string{text}) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + if len(embeddingsResult) == 0 { + err := fmt.Errorf("no embeddings generated") + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return nil, err + } + + // Track embedding time + duration := time.Since(start) + svc.embeddingTimeMu.Lock() + svc.totalEmbeddingTime += duration + svc.embeddingTimeMu.Unlock() + + span.SetAttributes( + attribute.Int64("embedding.duration_ms", duration.Milliseconds()), + ) + + return embeddingsResult[0], nil + } + + // Initialize database with embedding function + database, err := db.NewDatabase(config.DBConfig, embeddingFunc) + if err != nil { + _ = embeddingManager.Close() + return nil, 
fmt.Errorf("failed to initialize database: %w", err) + } + svc.database = database + + // Clear database on startup to ensure fresh embeddings + // This is important when the embedding model changes or for consistency + database.Reset() + logger.Info("Cleared optimizer database on startup") + + logger.Info("Ingestion service initialized for event-driven ingestion (chromem-go)") + return svc, nil +} + +// IngestServer ingests a single MCP server and its tools into the optimizer database. +// This is called by vMCP during session registration for each backend server. +// +// Parameters: +// - serverID: Unique identifier for the backend server +// - serverName: Human-readable server name +// - description: Optional server description +// - tools: List of tools available from this server +// +// This method will: +// 1. Create or update the backend server record (simplified metadata only) +// 2. Generate embeddings for server and tools +// 3. Count tokens for each tool +// 4. Store everything in the database for semantic search +// +// Note: URL, transport, status are NOT stored - vMCP manages backend lifecycle +func (s *Service) IngestServer( + ctx context.Context, + serverID string, + serverName string, + description *string, + tools []mcp.Tool, +) error { + // Create a span for the entire ingestion operation + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.ingest_server", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + + start := time.Now() + logger.Infof("Ingesting server: %s (%d tools) [serverID=%s]", serverName, len(tools), serverID) + + // Create backend server record (simplified - vMCP manages lifecycle) + // chromem-go will generate embeddings automatically from the content + backendServer := &models.BackendServer{ + ID: serverID, + Name: serverName, + Description: description, + Group: "default", // TODO: Pass group 
from vMCP if needed + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + // Create or update server (chromem-go handles embeddings) + if err := s.database.CreateOrUpdateServer(ctx, backendServer); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return fmt.Errorf("failed to create/update server %s: %w", serverName, err) + } + logger.Debugf("Created/updated server: %s", serverName) + + // Sync tools for this server + toolCount, err := s.syncBackendTools(ctx, serverID, serverName, tools) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return fmt.Errorf("failed to sync tools for %s: %w", serverName, err) + } + + duration := time.Since(start) + span.SetAttributes( + attribute.Int64("ingestion.duration_ms", duration.Milliseconds()), + attribute.Int("tools.ingested", toolCount), + ) + + logger.Infow("Successfully ingested server", + "server_name", serverName, + "server_id", serverID, + "tools_count", toolCount, + "duration_ms", duration.Milliseconds()) + return nil +} + +// syncBackendTools synchronizes tools for a backend server +func (s *Service) syncBackendTools(ctx context.Context, serverID string, serverName string, tools []mcp.Tool) (int, error) { + // Create a span for tool synchronization + ctx, span := s.tracer.Start(ctx, "optimizer.ingestion.sync_backend_tools", + trace.WithAttributes( + attribute.String("server.id", serverID), + attribute.String("server.name", serverName), + attribute.Int("tools.count", len(tools)), + )) + defer span.End() + + logger.Debugf("syncBackendTools: server=%s, serverID=%s, tool_count=%d", serverName, serverID, len(tools)) + + // Delete existing tools + if err := s.database.DeleteToolsByServer(ctx, serverID); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to delete existing tools: %w", err) + } + + if len(tools) == 0 { + return 0, nil + } + + // Create tool records (chromem-go will generate 
embeddings automatically) + for _, tool := range tools { + // Extract description for embedding + description := tool.Description + + // Convert InputSchema to JSON + schemaJSON, err := json.Marshal(tool.InputSchema) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to marshal input schema for tool %s: %w", tool.Name, err) + } + + backendTool := &models.BackendTool{ + ID: uuid.New().String(), + MCPServerID: serverID, + ToolName: tool.Name, + Description: &description, + InputSchema: schemaJSON, + TokenCount: s.tokenCounter.CountToolTokens(tool), + CreatedAt: time.Now(), + LastUpdated: time.Now(), + } + + if err := s.database.CreateTool(ctx, backendTool, serverName); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return 0, fmt.Errorf("failed to create tool %s: %w", tool.Name, err) + } + } + + logger.Infof("Synced %d tools for server %s", len(tools), serverName) + return len(tools), nil +} + +// GetEmbeddingManager returns the embedding manager for this service +func (s *Service) GetEmbeddingManager() *embeddings.Manager { + return s.embeddingManager +} + +// GetDatabase returns the database for search and retrieval operations +func (s *Service) GetDatabase() db.Database { + return s.database +} + +// GetTotalToolTokens returns the total token count across all tools in the database +func (s *Service) GetTotalToolTokens(ctx context.Context) int { + totalTokens, err := s.database.GetTotalToolTokens(ctx) + if err != nil { + logger.Warnw("Failed to get total tool tokens", "error", err) + return 0 + } + return totalTokens +} + +// GetTotalEmbeddingTime returns the total time spent calculating embeddings +func (s *Service) GetTotalEmbeddingTime() time.Duration { + s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + return s.totalEmbeddingTime +} + +// ResetEmbeddingTime resets the total embedding time counter +func (s *Service) ResetEmbeddingTime() { + 
s.embeddingTimeMu.Lock() + defer s.embeddingTimeMu.Unlock() + s.totalEmbeddingTime = 0 +} + +// Close releases resources +func (s *Service) Close() error { + var errs []error + + if err := s.embeddingManager.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close embedding manager: %w", err)) + } + + if err := s.database.Close(); err != nil { + errs = append(errs, fmt.Errorf("failed to close database: %w", err)) + } + + if len(errs) > 0 { + return fmt.Errorf("errors closing service: %v", errs) + } + + return nil +} diff --git a/pkg/vmcp/optimizer/internal/ingestion/service_test.go b/pkg/vmcp/optimizer/internal/ingestion/service_test.go new file mode 100644 index 0000000000..de4b7cda77 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/ingestion/service_test.go @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/db" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" +) + +// TestServiceCreationAndIngestion demonstrates the complete chromem-go workflow: +// 1. Create in-memory database +// 2. Initialize ingestion service +// 3. Ingest server and tools +// 4. 
Query the database +func TestServiceCreationAndIngestion(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create temporary directory for persistence (optional) + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + // Check for the actual model we'll use: nomic-embed-text + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "nomic-embed-text", + Dimension: 768, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available or model not found. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service with Ollama embeddings + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 768, + } + + svc, err := NewService(config) + if err != nil { + t.Skipf("Skipping test: Failed to create service. Error: %v. Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get the current weather for a location", + }, + { + Name: "search_web", + Description: "Search the web for information", + }, + } + + // Ingest server with tools + serverName := "test-server" + serverID := "test-server-id" + description := "A test MCP server" + + err = svc.IngestServer(ctx, serverID, serverName, &description, tools) + if err != nil { + // Check if error is due to missing model + errStr := err.Error() + if strings.Contains(errStr, "model") || strings.Contains(errStr, "not found") || strings.Contains(errStr, "404") { + t.Skipf("Skipping test: Model not available. Error: %v. 
Run 'ollama serve && ollama pull nomic-embed-text'", err) + return + } + require.NoError(t, err) + } + + // Query tools + allTools, err := svc.database.ListToolsByServer(ctx, serverID) + require.NoError(t, err) + require.Len(t, allTools, 2, "Expected 2 tools to be ingested") + + // Verify tool names + toolNames := make(map[string]bool) + for _, tool := range allTools { + toolNames[tool.ToolName] = true + } + require.True(t, toolNames["get_weather"], "get_weather tool should be present") + require.True(t, toolNames["search_web"], "search_web tool should be present") + + // Search for similar tools + hybridConfig := &db.HybridSearchConfig{ + SemanticRatio: 70, + Limit: 5, + ServerID: &serverID, + } + results, err := svc.database.SearchToolsHybrid(ctx, "weather information", hybridConfig) + require.NoError(t, err) + require.NotEmpty(t, results, "Should find at least one similar tool") + + require.NotEmpty(t, results, "Should return at least one result") + + // Weather tool should be most similar to weather query + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") + toolNamesFound := make(map[string]bool) + for _, result := range results { + toolNamesFound[result.ToolName] = true + } + require.True(t, toolNamesFound["get_weather"], "get_weather should be in results") + require.True(t, toolNamesFound["search_web"], "search_web should be in results") +} + +// TestService_EmbeddingTimeTracking tests that embedding time is tracked correctly +func TestService_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + // Initialize service + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Initially, embedding time should be 0 + initialTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), initialTime, "Initial embedding time should be 0") + + // Create test tools + tools := []mcp.Tool{ + { + Name: "test_tool_1", + Description: "First test tool for embedding", + }, + { + Name: "test_tool_2", + Description: "Second test tool for embedding", + }, + } + + // Reset embedding time before ingestion + svc.ResetEmbeddingTime() + + // Ingest server with tools (this will generate embeddings) + err = svc.IngestServer(ctx, "test-server-id", "TestServer", nil, tools) + require.NoError(t, err) + + // After ingestion, embedding time should be greater than 0 + totalEmbeddingTime := svc.GetTotalEmbeddingTime() + require.Greater(t, totalEmbeddingTime, time.Duration(0), + "Total embedding time should be greater than 0 after ingestion") + + // Reset and verify it's back to 0 + svc.ResetEmbeddingTime() + resetTime := svc.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), resetTime, "Embedding time should be 0 after reset") +} + +// TestServiceWithOllama demonstrates using real embeddings (requires Ollama running) +// This test can be enabled locally to verify Ollama integration +func TestServiceWithOllama(t *testing.T) { + t.Parallel() + + // Skip if not explicitly enabled or Ollama is not available + if os.Getenv("TEST_OLLAMA") != "true" { + t.Skip("Skipping Ollama integration test (set TEST_OLLAMA=true to enable)") + } + + ctx := context.Background() + tmpDir := t.TempDir() + + // 
Initialize service with Ollama embeddings + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "ollama-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Create test tools + tools := []mcp.Tool{ + { + Name: "get_weather", + Description: "Get current weather conditions for any location worldwide", + }, + { + Name: "send_email", + Description: "Send an email message to a recipient", + }, + } + + // Ingest server + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Search for weather-related tools + hybridConfig := &db.HybridSearchConfig{ + SemanticRatio: 70, + Limit: 5, + ServerID: nil, + } + results, err := svc.database.SearchToolsHybrid(ctx, "What's the temperature outside?", hybridConfig) + require.NoError(t, err) + require.NotEmpty(t, results) + + require.Equal(t, "get_weather", results[0].ToolName, + "Weather tool should be most similar to weather query") +} diff --git a/pkg/vmcp/optimizer/internal/ingestion/service_test_coverage.go b/pkg/vmcp/optimizer/internal/ingestion/service_test_coverage.go new file mode 100644 index 0000000000..dbe4d22f27 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/ingestion/service_test_coverage.go @@ -0,0 +1,273 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package ingestion + +import ( + "context" + "path/filepath" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/db" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" +) + +// TestService_GetTotalToolTokens tests token counting +func TestService_GetTotalToolTokens(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Ingest some tools + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + { + Name: "tool2", + Description: "Tool 2", + }, + } + + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + // Get total tokens + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetTotalToolTokens_NoFTS tests token counting without FTS +func TestService_GetTotalToolTokens_NoFTS(t *testing.T) { + t.Parallel() + ctx := context.Background() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := 
embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: "", // In-memory + FTSDBPath: "", // Will default to :memory: + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Get total tokens (should use FTS if available, fallback otherwise) + totalTokens := svc.GetTotalToolTokens(ctx) + assert.GreaterOrEqual(t, totalTokens, 0, "Total tokens should be non-negative") +} + +// TestService_GetDatabase tests database accessor +func TestService_GetDatabase(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + database := svc.GetDatabase() + require.NotNil(t, database) +} + +// TestService_GetEmbeddingManager tests embedding manager accessor +func TestService_GetEmbeddingManager(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + manager := svc.GetEmbeddingManager() + require.NotNil(t, manager) +} + +// TestService_IngestServer_ErrorHandling tests error handling during ingestion +func TestService_IngestServer_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + defer func() { _ = svc.Close() }() + + // Test with empty tools list + err = svc.IngestServer(ctx, "server-1", "TestServer", nil, []mcp.Tool{}) + require.NoError(t, err, "Should handle empty tools list gracefully") + + // Test with nil description + err = svc.IngestServer(ctx, "server-2", "TestServer2", nil, []mcp.Tool{ + { + Name: "tool1", + Description: "Tool 1", + }, + }) + require.NoError(t, err, "Should handle nil description gracefully") +} + +// TestService_Close_ErrorHandling tests error handling during close +func TestService_Close_ErrorHandling(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + DBConfig: &db.Config{ + PersistPath: filepath.Join(tmpDir, "test-db"), + }, + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + svc, err := NewService(config) + require.NoError(t, err) + + // Close should succeed + err = svc.Close() + require.NoError(t, err) + + // Multiple closes should be safe + err = svc.Close() + require.NoError(t, err) +} diff --git a/pkg/vmcp/optimizer/internal/models/errors.go b/pkg/vmcp/optimizer/internal/models/errors.go new file mode 100644 index 0000000000..c5b10eebe6 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/models/errors.go @@ -0,0 +1,19 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package models defines domain models for the optimizer. +// It includes structures for MCP servers, tools, and related metadata. +package models + +import "errors" + +var ( + // ErrRemoteServerMissingURL is returned when a remote server doesn't have a URL + ErrRemoteServerMissingURL = errors.New("remote servers must have URL") + + // ErrContainerServerMissingPackage is returned when a container server doesn't have a package + ErrContainerServerMissingPackage = errors.New("container servers must have package") + + // ErrInvalidTokenMetrics is returned when token metrics are inconsistent + ErrInvalidTokenMetrics = errors.New("invalid token metrics: calculated values don't match") +) diff --git a/pkg/vmcp/optimizer/internal/models/models.go b/pkg/vmcp/optimizer/internal/models/models.go new file mode 100644 index 0000000000..6c810fbe04 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/models/models.go @@ -0,0 +1,176 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "encoding/json" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// BaseMCPServer represents the common fields for MCP servers. +type BaseMCPServer struct { + ID string `json:"id"` + Name string `json:"name"` + Remote bool `json:"remote"` + Transport TransportType `json:"transport"` + Description *string `json:"description,omitempty"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + Group string `json:"group"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryServer represents an MCP server from the registry catalog. +type RegistryServer struct { + BaseMCPServer + URL *string `json:"url,omitempty"` // For remote servers + Package *string `json:"package,omitempty"` // For container servers +} + +// Validate checks if the registry server has valid data. +// Remote servers must have URL, container servers must have package. +func (r *RegistryServer) Validate() error { + if r.Remote && r.URL == nil { + return ErrRemoteServerMissingURL + } + if !r.Remote && r.Package == nil { + return ErrContainerServerMissingPackage + } + return nil +} + +// BackendServer represents a running MCP server backend. +// Simplified: Only stores metadata needed for tool organization and search results. +// vMCP manages backend lifecycle (URL, status, transport, etc.) +type BackendServer struct { + ID string `json:"id"` + Name string `json:"name"` + Description *string `json:"description,omitempty"` + Group string `json:"group"` + ServerEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// BaseTool represents the common fields for tools. 
+type BaseTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + Details mcp.Tool `json:"details"` + DetailsEmbedding []float32 `json:"-"` // Excluded from JSON, stored as BLOB + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// RegistryTool represents a tool from a registry MCP server. +type RegistryTool struct { + BaseTool +} + +// BackendTool represents a tool from a backend MCP server. +// With chromem-go, embeddings are managed by the database. +type BackendTool struct { + ID string `json:"id"` + MCPServerID string `json:"mcpserver_id"` + ToolName string `json:"tool_name"` + Description *string `json:"description,omitempty"` + InputSchema json.RawMessage `json:"input_schema,omitempty"` + ToolEmbedding []float32 `json:"-"` // Managed by chromem-go + TokenCount int `json:"token_count"` + LastUpdated time.Time `json:"last_updated"` + CreatedAt time.Time `json:"created_at"` +} + +// ToolDetailsToJSON converts mcp.Tool to JSON for storage in the database. +func ToolDetailsToJSON(tool mcp.Tool) (string, error) { + data, err := json.Marshal(tool) + if err != nil { + return "", err + } + return string(data), nil +} + +// ToolDetailsFromJSON converts JSON to mcp.Tool +func ToolDetailsFromJSON(data string) (*mcp.Tool, error) { + var tool mcp.Tool + err := json.Unmarshal([]byte(data), &tool) + if err != nil { + return nil, err + } + return &tool, nil +} + +// BackendToolWithMetadata represents a backend tool with similarity score. +type BackendToolWithMetadata struct { + BackendTool + Similarity float32 `json:"similarity"` // Cosine similarity from chromem-go (0-1, higher is better) +} + +// RegistryToolWithMetadata represents a registry tool with server information and similarity distance. 
+type RegistryToolWithMetadata struct { + ServerName string `json:"server_name"` + ServerDescription *string `json:"server_description,omitempty"` + Distance float64 `json:"distance"` // Cosine distance from query embedding + Tool RegistryTool `json:"tool"` +} + +// BackendWithRegistry represents a backend server with its resolved registry relationship. +type BackendWithRegistry struct { + Backend BackendServer `json:"backend"` + Registry *RegistryServer `json:"registry,omitempty"` // NULL if autonomous +} + +// EffectiveDescription returns the description (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveDescription() *string { + if b.Registry != nil { + return b.Registry.Description + } + return b.Backend.Description +} + +// EffectiveEmbedding returns the embedding (inherited from registry or own). +func (b *BackendWithRegistry) EffectiveEmbedding() []float32 { + if b.Registry != nil { + return b.Registry.ServerEmbedding + } + return b.Backend.ServerEmbedding +} + +// ServerNameForTools returns the server name to use as context for tool embeddings. +func (b *BackendWithRegistry) ServerNameForTools() string { + if b.Registry != nil { + return b.Registry.Name + } + return b.Backend.Name +} + +// TokenMetrics represents token efficiency metrics for tool filtering. +type TokenMetrics struct { + BaselineTokens int `json:"baseline_tokens"` // Total tokens for all running server tools + ReturnedTokens int `json:"returned_tokens"` // Total tokens for returned/filtered tools + TokensSaved int `json:"tokens_saved"` // Number of tokens saved by filtering + SavingsPercentage float64 `json:"savings_percentage"` // Percentage of tokens saved (0-100) +} + +// Validate checks if the token metrics are consistent. 
+func (t *TokenMetrics) Validate() error { + if t.TokensSaved != t.BaselineTokens-t.ReturnedTokens { + return ErrInvalidTokenMetrics + } + + var expectedPct float64 + if t.BaselineTokens > 0 { + expectedPct = (float64(t.TokensSaved) / float64(t.BaselineTokens)) * 100 + // Allow small floating point differences (0.01%) + if expectedPct-t.SavingsPercentage > 0.01 || t.SavingsPercentage-expectedPct > 0.01 { + return ErrInvalidTokenMetrics + } + } else if t.SavingsPercentage != 0.0 { + return ErrInvalidTokenMetrics + } + + return nil +} diff --git a/pkg/vmcp/optimizer/internal/models/models_test.go b/pkg/vmcp/optimizer/internal/models/models_test.go new file mode 100644 index 0000000000..af06e90bf4 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/models/models_test.go @@ -0,0 +1,273 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestRegistryServer_Validate(t *testing.T) { + t.Parallel() + url := "http://example.com/mcp" + pkg := "github.com/example/mcp-server" + + tests := []struct { + name string + server *RegistryServer + wantErr bool + }{ + { + name: "Remote server with URL is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + URL: &url, + }, + wantErr: false, + }, + { + name: "Container server with package is valid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + Package: &pkg, + }, + wantErr: false, + }, + { + name: "Remote server without URL is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: true, + }, + }, + wantErr: true, + }, + { + name: "Container server without package is invalid", + server: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Remote: false, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.server.Validate() + if (err 
!= nil) != tt.wantErr { + t.Errorf("RegistryServer.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestToolDetailsToJSON(t *testing.T) { + t.Parallel() + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool", + } + + json, err := ToolDetailsToJSON(tool) + if err != nil { + t.Fatalf("ToolDetailsToJSON() error = %v", err) + } + + if json == "" { + t.Error("ToolDetailsToJSON() returned empty string") + } + + // Try to parse it back + parsed, err := ToolDetailsFromJSON(json) + if err != nil { + t.Fatalf("ToolDetailsFromJSON() error = %v", err) + } + + if parsed.Name != tool.Name { + t.Errorf("Tool name mismatch: got %v, want %v", parsed.Name, tool.Name) + } + + if parsed.Description != tool.Description { + t.Errorf("Tool description mismatch: got %v, want %v", parsed.Description, tool.Description) + } +} + +func TestTokenMetrics_Validate(t *testing.T) { + t.Parallel() + tests := []struct { + name string + metrics *TokenMetrics + wantErr bool + }{ + { + name: "Valid metrics with savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 40.0, + }, + wantErr: false, + }, + { + name: "Valid metrics with no savings", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 1000, + TokensSaved: 0, + SavingsPercentage: 0.0, + }, + wantErr: false, + }, + { + name: "Invalid: tokens saved doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 500, // Should be 400 + SavingsPercentage: 40.0, + }, + wantErr: true, + }, + { + name: "Invalid: savings percentage doesn't match", + metrics: &TokenMetrics{ + BaselineTokens: 1000, + ReturnedTokens: 600, + TokensSaved: 400, + SavingsPercentage: 50.0, // Should be 40.0 + }, + wantErr: true, + }, + { + name: "Invalid: non-zero percentage with zero baseline", + metrics: &TokenMetrics{ + BaselineTokens: 0, + ReturnedTokens: 0, + TokensSaved: 0, + SavingsPercentage: 10.0, // 
Should be 0 + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := tt.metrics.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("TokenMetrics.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBackendWithRegistry_EffectiveDescription(t *testing.T) { + t.Parallel() + registryDesc := "Registry description" + backendDesc := "Backend description" + + tests := []struct { + name string + w *BackendWithRegistry + want *string + }{ + { + name: "Uses registry description when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Description: ®istryDesc, + }, + }, + }, + want: ®istryDesc, + }, + { + name: "Uses backend description when no registry", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Description: &backendDesc, + }, + Registry: nil, + }, + want: &backendDesc, + }, + { + name: "Returns nil when no description", + w: &BackendWithRegistry{ + Backend: BackendServer{}, + Registry: nil, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := tt.w.EffectiveDescription() + if (got == nil) != (tt.want == nil) { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", got, tt.want) + } + if got != nil && tt.want != nil && *got != *tt.want { + t.Errorf("BackendWithRegistry.EffectiveDescription() = %v, want %v", *got, *tt.want) + } + }) + } +} + +func TestBackendWithRegistry_ServerNameForTools(t *testing.T) { + t.Parallel() + tests := []struct { + name string + w *BackendWithRegistry + want string + }{ + { + name: "Uses registry name when available", + w: &BackendWithRegistry{ + Backend: BackendServer{ + Name: "backend-name", + }, + Registry: &RegistryServer{ + BaseMCPServer: BaseMCPServer{ + Name: "registry-name", + }, + }, + }, + want: "registry-name", + }, + { + name: "Uses 
// TransportType represents the transport protocol used by an MCP server.
// Maps 1:1 to ToolHive transport modes.
type TransportType string

const (
	// TransportSSE represents Server-Sent Events transport.
	TransportSSE TransportType = "sse"
	// TransportStreamable represents Streamable HTTP transport.
	TransportStreamable TransportType = "streamable-http"
)

// Valid returns true if the transport type is valid.
func (t TransportType) Valid() bool {
	switch t {
	case TransportSSE, TransportStreamable:
		return true
	default:
		return false
	}
}

// String returns the string representation.
func (t TransportType) String() string {
	return string(t)
}

// Value implements the driver.Valuer interface for database storage.
func (t TransportType) Value() (driver.Value, error) {
	if !t.Valid() {
		return nil, fmt.Errorf("invalid transport type: %s", t)
	}
	return string(t), nil
}

// Scan implements the sql.Scanner interface for database retrieval.
// It accepts both string and []byte payloads, because SQL drivers
// (notably SQLite) commonly return TEXT columns as []byte.
func (t *TransportType) Scan(value any) error {
	str, err := scanText(value, "transport type")
	if err != nil {
		return err
	}

	*t = TransportType(str)
	if !t.Valid() {
		return fmt.Errorf("invalid transport type from database: %s", str)
	}

	return nil
}

// scanText extracts the textual payload from a database scan value.
// label names the destination type in error messages so both Scan
// implementations can share this logic.
func scanText(value any, label string) (string, error) {
	switch v := value.(type) {
	case nil:
		return "", fmt.Errorf("%s cannot be nil", label)
	case string:
		return v, nil
	case []byte:
		// Drivers frequently hand back TEXT columns as raw bytes.
		return string(v), nil
	default:
		return "", fmt.Errorf("%s must be a string, got %T", label, value)
	}
}

// MCPStatus represents the status of an MCP server backend.
type MCPStatus string

const (
	// StatusRunning indicates the backend is running.
	StatusRunning MCPStatus = "running"
	// StatusStopped indicates the backend is stopped.
	StatusStopped MCPStatus = "stopped"
)

// Valid returns true if the status is valid.
func (s MCPStatus) Valid() bool {
	switch s {
	case StatusRunning, StatusStopped:
		return true
	default:
		return false
	}
}

// String returns the string representation.
func (s MCPStatus) String() string {
	return string(s)
}

// Value implements the driver.Valuer interface for database storage.
func (s MCPStatus) Value() (driver.Value, error) {
	if !s.Valid() {
		return nil, fmt.Errorf("invalid MCP status: %s", s)
	}
	return string(s), nil
}

// Scan implements the sql.Scanner interface for database retrieval.
// Like TransportType.Scan, it accepts both string and []byte payloads.
func (s *MCPStatus) Scan(value any) error {
	str, err := scanText(value, "MCP status")
	if err != nil {
		return err
	}

	*s = MCPStatus(str)
	if !s.Valid() {
		return fmt.Errorf("invalid MCP status from database: %s", str)
	}

	return nil
}
+// SPDX-License-Identifier: Apache-2.0 + +package models + +import ( + "testing" +) + +func TestTransportType_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + want bool + }{ + { + name: "SSE transport is valid", + transport: TransportSSE, + want: true, + }, + { + name: "Streamable transport is valid", + transport: TransportStreamable, + want: true, + }, + { + name: "Invalid transport is not valid", + transport: TransportType("invalid"), + want: false, + }, + { + name: "Empty transport is not valid", + transport: TransportType(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.transport.Valid(); got != tt.want { + t.Errorf("TransportType.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTransportType_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + transport TransportType + wantValue string + wantErr bool + }{ + { + name: "SSE transport value", + transport: TransportSSE, + wantValue: "sse", + wantErr: false, + }, + { + name: "Streamable transport value", + transport: TransportStreamable, + wantValue: "streamable-http", + wantErr: false, + }, + { + name: "Invalid transport returns error", + transport: TransportType("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := tt.transport.Value() + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("TransportType.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestTransportType_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want TransportType + wantErr bool + }{ + { + name: "Scan SSE transport", + value: "sse", + want: TransportSSE, + wantErr: false, + }, + { + name: "Scan streamable 
transport", + value: "streamable-http", + want: TransportStreamable, + wantErr: false, + }, + { + name: "Scan invalid transport returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var transport TransportType + err := transport.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("TransportType.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && transport != tt.want { + t.Errorf("TransportType.Scan() = %v, want %v", transport, tt.want) + } + }) + } +} + +func TestMCPStatus_Valid(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + want bool + }{ + { + name: "Running status is valid", + status: StatusRunning, + want: true, + }, + { + name: "Stopped status is valid", + status: StatusStopped, + want: true, + }, + { + name: "Invalid status is not valid", + status: MCPStatus("invalid"), + want: false, + }, + { + name: "Empty status is not valid", + status: MCPStatus(""), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := tt.status.Valid(); got != tt.want { + t.Errorf("MCPStatus.Valid() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestMCPStatus_Value(t *testing.T) { + t.Parallel() + tests := []struct { + name string + status MCPStatus + wantValue string + wantErr bool + }{ + { + name: "Running status value", + status: StatusRunning, + wantValue: "running", + wantErr: false, + }, + { + name: "Stopped status value", + status: StatusStopped, + wantValue: "stopped", + wantErr: false, + }, + { + name: "Invalid status returns error", + status: MCPStatus("invalid"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err 
:= tt.status.Value() + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Value() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && got != tt.wantValue { + t.Errorf("MCPStatus.Value() = %v, want %v", got, tt.wantValue) + } + }) + } +} + +func TestMCPStatus_Scan(t *testing.T) { + t.Parallel() + tests := []struct { + name string + value interface{} + want MCPStatus + wantErr bool + }{ + { + name: "Scan running status", + value: "running", + want: StatusRunning, + wantErr: false, + }, + { + name: "Scan stopped status", + value: "stopped", + want: StatusStopped, + wantErr: false, + }, + { + name: "Scan invalid status returns error", + value: "invalid", + wantErr: true, + }, + { + name: "Scan nil returns error", + value: nil, + wantErr: true, + }, + { + name: "Scan non-string returns error", + value: 123, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var status MCPStatus + err := status.Scan(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("MCPStatus.Scan() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && status != tt.want { + t.Errorf("MCPStatus.Scan() = %v, want %v", status, tt.want) + } + }) + } +} diff --git a/pkg/vmcp/optimizer/internal/tokens/counter.go b/pkg/vmcp/optimizer/internal/tokens/counter.go new file mode 100644 index 0000000000..11ed33c118 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/tokens/counter.go @@ -0,0 +1,68 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package tokens provides token counting utilities for LLM cost estimation. +// It estimates token counts for MCP tools and their metadata. 
+package tokens + +import ( + "encoding/json" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Counter counts tokens for LLM consumption +// This provides estimates of token usage for tools +type Counter struct { + // Simple heuristic: ~4 characters per token for English text + charsPerToken float64 +} + +// NewCounter creates a new token counter +func NewCounter() *Counter { + return &Counter{ + charsPerToken: 4.0, // GPT-style tokenization approximation + } +} + +// CountToolTokens estimates the number of tokens for a tool +func (c *Counter) CountToolTokens(tool mcp.Tool) int { + // Convert tool to JSON representation (as it would be sent to LLM) + toolJSON, err := json.Marshal(tool) + if err != nil { + // Fallback to simple estimation + return c.estimateFromTool(tool) + } + + // Estimate tokens from JSON length + return int(float64(len(toolJSON)) / c.charsPerToken) +} + +// estimateFromTool provides a fallback estimation from tool fields +func (c *Counter) estimateFromTool(tool mcp.Tool) int { + totalChars := len(tool.Name) + + if tool.Description != "" { + totalChars += len(tool.Description) + } + + // Estimate input schema size + schemaJSON, _ := json.Marshal(tool.InputSchema) + totalChars += len(schemaJSON) + + return int(float64(totalChars) / c.charsPerToken) +} + +// CountToolsTokens calculates total tokens for multiple tools +func (c *Counter) CountToolsTokens(tools []mcp.Tool) int { + total := 0 + for _, tool := range tools { + total += c.CountToolTokens(tool) + } + return total +} + +// EstimateText estimates tokens for arbitrary text +func (c *Counter) EstimateText(text string) int { + return int(float64(len(text)) / c.charsPerToken) +} diff --git a/pkg/vmcp/optimizer/internal/tokens/counter_test.go b/pkg/vmcp/optimizer/internal/tokens/counter_test.go new file mode 100644 index 0000000000..082ee385a1 --- /dev/null +++ b/pkg/vmcp/optimizer/internal/tokens/counter_test.go @@ -0,0 +1,146 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package tokens + +import ( + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestCountToolTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + Description: "A test tool for counting tokens", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count, got %d", tokens) + } + + // Rough estimate: tool should have at least a few tokens + if tokens < 5 { + t.Errorf("Expected at least 5 tokens for a tool with name and description, got %d", tokens) + } +} + +func TestCountToolTokens_MinimalTool(t *testing.T) { + t.Parallel() + counter := NewCounter() + + // Minimal tool with just a name + tool := mcp.Tool{ + Name: "minimal", + } + + tokens := counter.CountToolTokens(tool) + + // Should return a positive number even for minimal tool + if tokens <= 0 { + t.Errorf("Expected positive token count for minimal tool, got %d", tokens) + } +} + +func TestCountToolTokens_NoDescription(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tool := mcp.Tool{ + Name: "test_tool", + } + + tokens := counter.CountToolTokens(tool) + + // Should still return a positive number + if tokens <= 0 { + t.Errorf("Expected positive token count for tool without description, got %d", tokens) + } +} + +func TestCountToolsTokens(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tools := []mcp.Tool{ + { + Name: "tool1", + Description: "First tool", + }, + { + Name: "tool2", + Description: "Second tool with longer description", + }, + } + + totalTokens := counter.CountToolsTokens(tools) + + // Should be greater than individual tools + tokens1 := counter.CountToolTokens(tools[0]) + tokens2 := counter.CountToolTokens(tools[1]) + + expectedTotal := tokens1 + tokens2 + if totalTokens != expectedTotal { + t.Errorf("Expected total tokens %d, got %d", expectedTotal, totalTokens) + } +} + +func 
TestCountToolsTokens_EmptyList(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tokens := counter.CountToolsTokens([]mcp.Tool{}) + + // Should return 0 for empty list + if tokens != 0 { + t.Errorf("Expected 0 tokens for empty list, got %d", tokens) + } +} + +func TestEstimateText(t *testing.T) { + t.Parallel() + counter := NewCounter() + + tests := []struct { + name string + text string + want int + }{ + { + name: "Empty text", + text: "", + want: 0, + }, + { + name: "Short text", + text: "Hello", + want: 1, // 5 chars / 4 chars per token ≈ 1 + }, + { + name: "Medium text", + text: "This is a test message", + want: 5, // 22 chars / 4 chars per token ≈ 5 + }, + { + name: "Long text", + text: "This is a much longer test message that should have more tokens because it contains significantly more characters", + want: 28, // 112 chars / 4 chars per token = 28 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := counter.EstimateText(tt.text) + if got != tt.want { + t.Errorf("EstimateText() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/vmcp/optimizer/optimizer.go b/pkg/vmcp/optimizer/optimizer.go index fea0425bb5..14ec0734c9 100644 --- a/pkg/vmcp/optimizer/optimizer.go +++ b/pkg/vmcp/optimizer/optimizer.go @@ -1,47 +1,103 @@ // SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. // SPDX-License-Identifier: Apache-2.0 -// Package optimizer provides the Optimizer interface for intelligent tool discovery -// and invocation in the Virtual MCP Server. +// Package optimizer provides semantic tool discovery for Virtual MCP Server. 
// -// When the optimizer is enabled, vMCP exposes only two tools to clients: -// - find_tool: Semantic search over available tools -// - call_tool: Dynamic invocation of any backend tool +// The optimizer reduces token usage by exposing only two tools to clients: +// - optim_find_tool: Semantic search over available tools +// - optim_call_tool: Dynamic invocation of backend tools // -// This reduces token usage by avoiding the need to send all tool definitions -// to the LLM, instead allowing it to discover relevant tools on demand. +// This allows LLMs to discover relevant tools on-demand instead of receiving +// all tool definitions upfront. +// +// Architecture: +// - Public API defined by Optimizer interface +// - Implementation details in internal/ subpackages +// - Embeddings generated once at startup for efficiency package optimizer import ( "context" "encoding/json" + "fmt" + "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" + + "github.com/stacklok/toolhive/pkg/logger" + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/config" + "github.com/stacklok/toolhive/pkg/vmcp/discovery" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/db" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/ingestion" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/models" + "github.com/stacklok/toolhive/pkg/vmcp/server/adapter" ) +// Config is a type alias for config.OptimizerConfig, provided for test compatibility. +// Deprecated: Use config.OptimizerConfig directly. +type Config = config.OptimizerConfig + +// Integration is a type alias for EmbeddingOptimizer, provided for test compatibility. 
+// Deprecated: Use *EmbeddingOptimizer directly. +type Integration = EmbeddingOptimizer + +//nolint:revive // OptimizerIntegration kept for backward compatibility in tests +type OptimizerIntegration = EmbeddingOptimizer + // Optimizer defines the interface for intelligent tool discovery and invocation. // -// Implementations may use various strategies for tool matching: -// - DummyOptimizer: Exact string matching (for testing) -// - EmbeddingOptimizer: Semantic similarity via embeddings (production) +// Implementations manage their own lifecycle, including: +// - Embedding generation and database management +// - Backend tool ingestion at startup +// - Resource cleanup on shutdown +// +// The optimizer is called via MCP tool handlers (optim_find_tool, optim_call_tool) +// which delegate to these methods. type Optimizer interface { // FindTool searches for tools matching the given description and keywords. - // Returns matching tools ranked by relevance score. + // Returns matching tools ranked by relevance with token savings metrics. FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error) // CallTool invokes a tool by name with the given parameters. - // Returns the tool's result or an error if the tool is not found or execution fails. - // Returns the MCP CallToolResult directly from the underlying tool handler. + // Handles tool name resolution and routing to the correct backend. CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) + + // Close cleans up optimizer resources (databases, caches, connections). + Close() error + + // HandleSessionRegistration handles session-specific setup for optimizer mode. + // Returns true if optimizer handled the registration, false otherwise. 
+ HandleSessionRegistration( + ctx context.Context, + sessionID string, + caps *aggregator.AggregatedCapabilities, + mcpServer *server.MCPServer, + resourceConverter func([]vmcp.Resource) []server.ServerResource, + ) (bool, error) + + // OptimizerHandlerProvider provides tool handlers for adapter integration + adapter.OptimizerHandlerProvider } // FindToolInput contains the parameters for finding tools. type FindToolInput struct { // ToolDescription is a natural language description of the tool to find. - ToolDescription string `json:"tool_description" description:"Natural language description of the tool to find"` + ToolDescription string `json:"tool_description"` - // ToolKeywords is an optional list of keywords to narrow the search. - ToolKeywords []string `json:"tool_keywords,omitempty" description:"Optional keywords to narrow search"` + // ToolKeywords is an optional space-separated list of keywords to narrow search. + ToolKeywords string `json:"tool_keywords,omitempty"` + + // Limit is the maximum number of tools to return (default: 10). + Limit int `json:"limit,omitempty"` } // FindToolOutput contains the results of a tool search. @@ -49,43 +105,791 @@ type FindToolOutput struct { // Tools contains the matching tools, ranked by relevance. Tools []ToolMatch `json:"tools"` - // TokenMetrics provides information about token savings from using the optimizer. + // TokenMetrics provides information about token savings. TokenMetrics TokenMetrics `json:"token_metrics"` } // ToolMatch represents a tool that matched the search criteria. type ToolMatch struct { - // Name is the unique identifier of the tool. + // Name is the resolved name of the tool (after conflict resolution). Name string `json:"name"` // Description is the human-readable description of the tool. Description string `json:"description"` // InputSchema is the JSON schema for the tool's input parameters. - // Uses json.RawMessage to preserve the original schema format. 
- InputSchema json.RawMessage `json:"input_schema"` + InputSchema map[string]any `json:"input_schema"` + + // BackendID is the ID of the backend that provides this tool. + BackendID string `json:"backend_id"` + + // SimilarityScore indicates relevance (0.0-1.0, higher is better). + SimilarityScore float64 `json:"similarity_score"` - // Score indicates how well this tool matches the search criteria (0.0-1.0). - Score float64 `json:"score"` + // TokenCount is the estimated tokens for this tool's definition. + TokenCount int `json:"token_count"` } // TokenMetrics provides information about token usage optimization. type TokenMetrics struct { - // BaselineTokens is the estimated tokens if all tools were sent. + // BaselineTokens is the total tokens if all tools were sent. BaselineTokens int `json:"baseline_tokens"` - // ReturnedTokens is the actual tokens for the returned tools. + // ReturnedTokens is the tokens for the returned tools. ReturnedTokens int `json:"returned_tokens"` - // SavingsPercent is the percentage of tokens saved. - SavingsPercent float64 `json:"savings_percent"` + // TokensSaved is the number of tokens saved by filtering. + TokensSaved int `json:"tokens_saved"` + + // SavingsPercentage is the percentage of tokens saved (0-100). + SavingsPercentage float64 `json:"savings_percentage"` } // CallToolInput contains the parameters for calling a tool. type CallToolInput struct { + // BackendID is the ID of the backend that provides the tool. + BackendID string `json:"backend_id"` + // ToolName is the name of the tool to invoke. - ToolName string `json:"tool_name" description:"Name of the tool to call"` + ToolName string `json:"tool_name"` // Parameters are the arguments to pass to the tool. - Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` + Parameters map[string]any `json:"parameters"` +} + +// Factory creates an Optimizer instance with direct backend access. 
+// Called once at startup to enable efficient ingestion and embedding generation. +type Factory func( + ctx context.Context, + cfg *config.OptimizerConfig, + mcpServer *server.MCPServer, + backendClient vmcp.BackendClient, + sessionManager *transportsession.Manager, +) (Optimizer, error) + +// EmbeddingOptimizer implements Optimizer using semantic embeddings and hybrid search. +// +// Architecture: +// - Uses chromem-go for vector embeddings (in-memory or persisted) +// - Uses SQLite FTS5 for BM25 keyword search +// - Combines both for hybrid semantic + keyword matching +// - Ingests backends once at startup, not per-session +type EmbeddingOptimizer struct { + config *config.OptimizerConfig + ingestionService *ingestion.Service + mcpServer *server.MCPServer + backendClient vmcp.BackendClient + sessionManager *transportsession.Manager + tracer trace.Tracer +} + +// NewIntegration is an alias for NewEmbeddingOptimizer, provided for test compatibility. +// Returns the concrete type to allow access to test helper methods. +// Deprecated: Use NewEmbeddingOptimizer directly. +func NewIntegration( + ctx context.Context, + cfg *config.OptimizerConfig, + mcpServer *server.MCPServer, + backendClient vmcp.BackendClient, + sessionManager *transportsession.Manager, +) (*EmbeddingOptimizer, error) { + opt, err := NewEmbeddingOptimizer(ctx, cfg, mcpServer, backendClient, sessionManager) + if err != nil { + return nil, err + } + if opt == nil { + return nil, nil + } + return opt.(*EmbeddingOptimizer), nil +} + +// NewEmbeddingOptimizer is a Factory that creates an embedding-based optimizer. +// This is the production implementation using semantic embeddings. 
+func NewEmbeddingOptimizer( + _ context.Context, + cfg *config.OptimizerConfig, + mcpServer *server.MCPServer, + backendClient vmcp.BackendClient, + sessionManager *transportsession.Manager, +) (Optimizer, error) { + if cfg == nil || !cfg.Enabled { + return nil, nil // Optimizer disabled + } + + // Initialize ingestion service with embedding backend + ingestionCfg := &ingestion.Config{ + DBConfig: &db.Config{ + PersistPath: cfg.PersistPath, + FTSDBPath: cfg.FTSDBPath, + }, + // Pass individual embedding fields + EmbeddingBackend: cfg.EmbeddingBackend, + EmbeddingURL: cfg.EmbeddingURL, + EmbeddingModel: cfg.EmbeddingModel, + EmbeddingDimension: cfg.EmbeddingDimension, + } + + svc, err := ingestion.NewService(ingestionCfg) + if err != nil { + return nil, fmt.Errorf("failed to initialize ingestion service: %w", err) + } + + opt := &EmbeddingOptimizer{ + config: cfg, + ingestionService: svc, + mcpServer: mcpServer, + backendClient: backendClient, + sessionManager: sessionManager, + tracer: otel.Tracer("github.com/stacklok/toolhive/pkg/vmcp/optimizer"), + } + + return opt, nil +} + +// Ensure EmbeddingOptimizer implements Optimizer interface at compile time. +var _ Optimizer = (*EmbeddingOptimizer)(nil) + +// FindTool implements Optimizer.FindTool using hybrid semantic + keyword search. 
+func (o *EmbeddingOptimizer) FindTool(ctx context.Context, input FindToolInput) (*FindToolOutput, error) { + // Get database for search + if o.ingestionService == nil { + return nil, fmt.Errorf("ingestion service not initialized") + } + database := o.ingestionService.GetDatabase() + if database == nil { + return nil, fmt.Errorf("database not initialized") + } + + // Configure hybrid search + limit := input.Limit + if limit <= 0 { + limit = 10 // Default + } + + // Handle HybridSearchRatio (pointer in config, with default) + hybridRatio := 70 // Default + if o.config.HybridSearchRatio != nil { + hybridRatio = *o.config.HybridSearchRatio + } + + hybridConfig := &db.HybridSearchConfig{ + SemanticRatio: hybridRatio, + Limit: limit, + ServerID: nil, // Search across all servers + } + + // Build query text + queryText := input.ToolDescription + if input.ToolKeywords != "" { + queryText = queryText + " " + input.ToolKeywords + } + + // Execute hybrid search + results, err := database.SearchToolsHybrid(ctx, queryText, hybridConfig) + if err != nil { + logger.Errorw("Hybrid search failed", + "error", err, + "tool_description", input.ToolDescription, + "tool_keywords", input.ToolKeywords) + return nil, fmt.Errorf("search failed: %w", err) + } + + // Get routing table from context to resolve tool names + var routingTable *vmcp.RoutingTable + if capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx); ok && capabilities != nil { + routingTable = capabilities.RoutingTable + } + + // Convert results to output format + tools, totalReturnedTokens := o.convertSearchResults(results, routingTable) + + // Calculate token metrics + baselineTokens := o.ingestionService.GetTotalToolTokens(ctx) + tokensSaved := baselineTokens - totalReturnedTokens + savingsPercentage := 0.0 + if baselineTokens > 0 { + savingsPercentage = (float64(tokensSaved) / float64(baselineTokens)) * 100.0 + } + + // Record OpenTelemetry metrics + o.recordTokenMetrics(ctx, baselineTokens, 
totalReturnedTokens, tokensSaved, savingsPercentage) + + logger.Infow("optim_find_tool completed", + "query", input.ToolDescription, + "results_count", len(tools), + "tokens_saved", tokensSaved, + "savings_percentage", fmt.Sprintf("%.2f%%", savingsPercentage)) + + return &FindToolOutput{ + Tools: tools, + TokenMetrics: TokenMetrics{ + BaselineTokens: baselineTokens, + ReturnedTokens: totalReturnedTokens, + TokensSaved: tokensSaved, + SavingsPercentage: savingsPercentage, + }, + }, nil +} + +// CallTool implements Optimizer.CallTool by routing to the correct backend. +func (o *EmbeddingOptimizer) CallTool(ctx context.Context, input CallToolInput) (*mcp.CallToolResult, error) { + // Resolve target backend + target, backendToolName, err := o.resolveToolTarget(ctx, input.BackendID, input.ToolName) + if err != nil { + return nil, err + } + + logger.Infow("Calling tool via optimizer", + "backend_id", input.BackendID, + "tool_name", input.ToolName, + "backend_tool_name", backendToolName, + "workload_name", target.WorkloadName) + + // Call the tool on the backend + result, err := o.backendClient.CallTool(ctx, target, backendToolName, input.Parameters, nil) + if err != nil { + logger.Errorw("Tool call failed", + "error", err, + "backend_id", input.BackendID, + "tool_name", input.ToolName, + "backend_tool_name", backendToolName) + return nil, fmt.Errorf("tool call failed: %w", err) + } + + // Convert result to MCP format + mcpResult := convertToolResult(result) + + logger.Infow("optim_call_tool completed successfully", + "backend_id", input.BackendID, + "tool_name", input.ToolName) + + return mcpResult, nil +} + +// Close implements Optimizer.Close by cleaning up resources. +func (o *EmbeddingOptimizer) Close() error { + if o == nil || o.ingestionService == nil { + return nil + } + return o.ingestionService.Close() +} + +// HandleSessionRegistration implements Optimizer.HandleSessionRegistration. 
+func (o *EmbeddingOptimizer) HandleSessionRegistration( + _ context.Context, + sessionID string, + caps *aggregator.AggregatedCapabilities, + mcpServer *server.MCPServer, + resourceConverter func([]vmcp.Resource) []server.ServerResource, +) (bool, error) { + logger.Debugw("HandleSessionRegistration called for optimizer mode", "session_id", sessionID) + + // Register optimizer tools for this session + optimizerTools, err := adapter.CreateOptimizerToolsFromProvider(o) + if err != nil { + return false, fmt.Errorf("failed to create optimizer tools: %w", err) + } + + // Add optimizer tools to session + if err := mcpServer.AddSessionTools(sessionID, optimizerTools...); err != nil { + return false, fmt.Errorf("failed to add optimizer tools to session: %w", err) + } + + logger.Debugw("Optimizer tools registered for session", "session_id", sessionID) + + // Inject resources (but not backend tools or composite tools) + if len(caps.Resources) > 0 { + sdkResources := resourceConverter(caps.Resources) + if err := mcpServer.AddSessionResources(sessionID, sdkResources...); err != nil { + return false, fmt.Errorf("failed to add session resources: %w", err) + } + logger.Debugw("Added session resources (optimizer mode)", + "session_id", sessionID, + "count", len(sdkResources)) + } + + logger.Infow("Optimizer mode: backend tools not exposed directly", + "session_id", sessionID, + "backend_tool_count", len(caps.Tools), + "resource_count", len(caps.Resources)) + + return true, nil // Optimizer handled the registration +} + +// CreateFindToolHandler implements adapter.OptimizerHandlerProvider. 
+func (o *EmbeddingOptimizer) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + logger.Debugw("optim_find_tool called", "request", request) + + // Extract parameters + args, ok := request.Params.Arguments.(map[string]any) + if !ok { + return mcp.NewToolResultError("invalid arguments: expected object"), nil + } + + // Extract and validate parameters + toolDescription, toolKeywords, limit, err := extractFindToolParams(args) + if err != nil { + return err, nil + } + + // Call FindTool + output, findErr := o.FindTool(ctx, FindToolInput{ + ToolDescription: toolDescription, + ToolKeywords: toolKeywords, + Limit: limit, + }) + if findErr != nil { + return mcp.NewToolResultError(fmt.Sprintf("search failed: %v", findErr)), nil + } + + // Marshal response to JSON + responseJSON, marshalErr := json.Marshal(output) + if marshalErr != nil { + logger.Errorw("Failed to marshal response", "error", marshalErr) + return mcp.NewToolResultError(fmt.Sprintf("failed to marshal response: %v", marshalErr)), nil + } + + return mcp.NewToolResultText(string(responseJSON)), nil + } +} + +// CreateCallToolHandler implements adapter.OptimizerHandlerProvider. 
+func (o *EmbeddingOptimizer) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+	return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
+		logger.Debugw("optim_call_tool called", "request", request)
+
+		// Parse request. Validation failures are surfaced to the client as
+		// tool-result errors (nil Go error), per MCP conventions.
+		backendID, toolName, parameters, err := parseCallToolRequest(request)
+		if err != nil {
+			return mcp.NewToolResultError(err.Error()), nil
+		}
+
+		// Call CallTool
+		result, err := o.CallTool(ctx, CallToolInput{
+			BackendID:  backendID,
+			ToolName:   toolName,
+			Parameters: parameters,
+		})
+		if err != nil {
+			return mcp.NewToolResultError(err.Error()), nil
+		}
+
+		return result, nil
+	}
+}
+
+// Initialize performs optimizer initialization (registers tools, ingests backends).
+// This should be called once during server startup.
+func (o *EmbeddingOptimizer) Initialize(
+	ctx context.Context,
+	mcpServer *server.MCPServer,
+	backendRegistry vmcp.BackendRegistry,
+) error {
+	// Register optimizer tools globally (server-wide), as opposed to the
+	// per-session registration done in HandleSessionRegistration.
+	optimizerTools, err := adapter.CreateOptimizerToolsFromProvider(o)
+	if err != nil {
+		return fmt.Errorf("failed to create optimizer tools: %w", err)
+	}
+	for _, tool := range optimizerTools {
+		mcpServer.AddTool(tool.Tool, tool.Handler)
+	}
+	logger.Info("Optimizer tools registered globally")
+
+	// Ingest discovered backends
+	initialBackends := backendRegistry.List(ctx)
+	if err := o.IngestInitialBackends(ctx, initialBackends); err != nil {
+		logger.Warnf("Failed to ingest initial backends: %v", err)
+		// Don't fail initialization - optimizer can still work with incremental ingestion
+	}
+
+	return nil
+}
+
+// IngestInitialBackends ingests all discovered backends and their tools at startup.
+func (o *EmbeddingOptimizer) IngestInitialBackends(ctx context.Context, backends []vmcp.Backend) error {
+	// A nil receiver or missing ingestion service means the optimizer is
+	// disabled; report zero embedding time and do nothing.
+	if o == nil || o.ingestionService == nil {
+		logger.Infow("Optimizer disabled, embedding time: 0ms")
+		return nil
+	}
+
+	// Reset embedding time before starting ingestion
+	o.ingestionService.ResetEmbeddingTime()
+
+	// Create a span for the entire ingestion process
+	ctx, span := o.tracer.Start(ctx, "optimizer.ingestion.ingest_initial_backends",
+		trace.WithAttributes(
+			attribute.Int("backends.count", len(backends)),
+		))
+	defer span.End()
+
+	start := time.Now()
+	logger.Infof("Ingesting %d discovered backends into optimizer", len(backends))
+
+	ingestedCount := 0
+	totalToolsIngested := 0
+	// Best-effort loop: any backend that fails conversion, capability listing,
+	// or ingestion is logged, recorded on its span, and skipped — one bad
+	// backend never aborts the whole startup ingestion.
+	for _, backend := range backends {
+		// Create a span for each backend ingestion
+		backendCtx, backendSpan := o.tracer.Start(ctx, "optimizer.ingestion.ingest_backend",
+			trace.WithAttributes(
+				attribute.String("backend.id", backend.ID),
+				attribute.String("backend.name", backend.Name),
+			))
+
+		// Convert Backend to BackendTarget for client API
+		target := vmcp.BackendToTarget(&backend)
+		if target == nil {
+			logger.Warnf("Failed to convert backend %s to target", backend.Name)
+			backendSpan.RecordError(fmt.Errorf("failed to convert backend to target"))
+			backendSpan.SetStatus(codes.Error, "conversion failed")
+			backendSpan.End()
+			continue
+		}
+
+		// Query backend capabilities to get its tools
+		capabilities, err := o.backendClient.ListCapabilities(backendCtx, target)
+		if err != nil {
+			logger.Warnf("Failed to query capabilities for backend %s: %v", backend.Name, err)
+			backendSpan.RecordError(err)
+			backendSpan.SetStatus(codes.Error, err.Error())
+			backendSpan.End()
+			continue
+		}
+
+		// Extract tools from capabilities (only name/description are indexed here)
+		var tools []mcp.Tool
+		for _, tool := range capabilities.Tools {
+			tools = append(tools, mcp.Tool{
+				Name:        tool.Name,
+				Description: tool.Description,
+			})
+		}
+
+		// Get description from metadata; left nil when absent or empty
+		var description *string
+		if backend.Metadata != nil {
+			if desc := backend.Metadata["description"]; desc != "" {
+				description = &desc
+			}
+		}
+
+		backendSpan.SetAttributes(
+			attribute.Int("tools.count", len(tools)),
+		)
+
+		// Ingest this backend's tools
+		if err := o.ingestionService.IngestServer(
+			backendCtx,
+			backend.ID,
+			backend.Name,
+			description,
+			tools,
+		); err != nil {
+			logger.Warnf("Failed to ingest backend %s: %v", backend.Name, err)
+			backendSpan.RecordError(err)
+			backendSpan.SetStatus(codes.Error, err.Error())
+			backendSpan.End()
+			continue
+		}
+		ingestedCount++
+		totalToolsIngested += len(tools)
+		backendSpan.SetAttributes(
+			attribute.Int("tools.ingested", len(tools)),
+		)
+		backendSpan.SetStatus(codes.Ok, "backend ingested successfully")
+		backendSpan.End()
+	}
+
+	// Get total embedding time
+	totalEmbeddingTime := o.ingestionService.GetTotalEmbeddingTime()
+	totalDuration := time.Since(start)
+
+	span.SetAttributes(
+		attribute.Int64("ingestion.duration_ms", totalDuration.Milliseconds()),
+		attribute.Int64("embedding.duration_ms", totalEmbeddingTime.Milliseconds()),
+		attribute.Int("backends.ingested", ingestedCount),
+		attribute.Int("tools.ingested", totalToolsIngested),
+	)
+
+	logger.Infow("Initial backend ingestion completed",
+		"servers_ingested", ingestedCount,
+		"tools_ingested", totalToolsIngested,
+		"total_duration_ms", totalDuration.Milliseconds(),
+		"total_embedding_time_ms", totalEmbeddingTime.Milliseconds(),
+		"embedding_time_percentage", fmt.Sprintf("%.2f%%", float64(totalEmbeddingTime)/float64(totalDuration)*100))
+
+	return nil
+}
+
+// Helper methods
+
+// convertSearchResults converts database search results to ToolMatch format.
+func (*EmbeddingOptimizer) convertSearchResults(
+	results []*models.BackendToolWithMetadata,
+	routingTable *vmcp.RoutingTable,
+) ([]ToolMatch, int) {
+	tools := make([]ToolMatch, 0, len(results))
+	totalReturnedTokens := 0
+
+	for _, result := range results {
+		// Unmarshal InputSchema.
+		// A malformed stored schema degrades to an empty schema (with a warning)
+		// rather than failing the whole search.
+		var inputSchema map[string]any
+		if len(result.InputSchema) > 0 {
+			if err := json.Unmarshal(result.InputSchema, &inputSchema); err != nil {
+				logger.Warnw("Failed to unmarshal input schema",
+					"tool_id", result.ID,
+					"tool_name", result.ToolName,
+					"error", err)
+				inputSchema = map[string]any{}
+			}
+		}
+
+		// Handle nil description
+		description := ""
+		if result.Description != nil {
+			description = *result.Description
+		}
+
+		// Resolve tool name using routing table (maps stored/original names to
+		// the names actually exposed to clients)
+		resolvedName := resolveToolName(routingTable, result.MCPServerID, result.ToolName)
+
+		tool := ToolMatch{
+			Name:            resolvedName,
+			Description:     description,
+			InputSchema:     inputSchema,
+			BackendID:       result.MCPServerID,
+			SimilarityScore: float64(result.Similarity),
+			TokenCount:      result.TokenCount,
+		}
+		tools = append(tools, tool)
+		totalReturnedTokens += result.TokenCount
+	}
+
+	// Second return value is the summed token count of every returned match,
+	// used by callers for token-savings accounting.
+	return tools, totalReturnedTokens
+}
+
+// resolveToolTarget finds and validates the target backend for a tool.
+func (*EmbeddingOptimizer) resolveToolTarget(
+	ctx context.Context,
+	backendID string,
+	toolName string,
+) (*vmcp.BackendTarget, string, error) {
+	// Routing data travels on the request context (populated by the discovery
+	// layer); without it we cannot map tool names to backends.
+	capabilities, ok := discovery.DiscoveredCapabilitiesFromContext(ctx)
+	if !ok || capabilities == nil {
+		return nil, "", fmt.Errorf("routing information not available in context")
+	}
+
+	if capabilities.RoutingTable == nil || capabilities.RoutingTable.Tools == nil {
+		return nil, "", fmt.Errorf("routing table not initialized")
+	}
+
+	target, exists := capabilities.RoutingTable.Tools[toolName]
+	if !exists {
+		return nil, "", fmt.Errorf("tool not found in routing table: %s", toolName)
+	}
+
+	// Guard against a client pairing a tool with the wrong backend_id.
+	if target.WorkloadID != backendID {
+		return nil, "", fmt.Errorf("tool %s belongs to backend %s, not %s",
+			toolName, target.WorkloadID, backendID)
+	}
+
+	// Translate the client-facing name back to the backend's own tool name.
+	backendToolName := target.GetBackendCapabilityName(toolName)
+	return target, backendToolName, nil
+}
+
+// recordTokenMetrics records OpenTelemetry metrics for token savings.
+// Instrument-creation failures are logged at debug level and the whole
+// recording is skipped; metrics are best-effort and never fail the request.
+func (*EmbeddingOptimizer) recordTokenMetrics(
+	ctx context.Context,
+	baselineTokens int,
+	returnedTokens int,
+	tokensSaved int,
+	savingsPercentage float64,
+) {
+	meter := otel.Meter("github.com/stacklok/toolhive/pkg/vmcp/optimizer")
+
+	// NOTE(review): instruments are re-created on every call; the OTel SDK is
+	// expected to dedupe them by name, but hoisting creation out of the hot
+	// path would avoid repeated lookups — confirm against the SDK in use.
+	baselineCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_baseline_tokens",
+		metric.WithDescription("Total tokens for all tools in the optimizer database (baseline)"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create baseline_tokens counter", "error", err)
+		return
+	}
+
+	returnedCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_returned_tokens",
+		metric.WithDescription("Total tokens for tools returned by optim_find_tool"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create returned_tokens counter", "error", err)
+		return
+	}
+
+	savedCounter, err := meter.Int64Counter(
+		"toolhive_vmcp_optimizer_tokens_saved",
+		metric.WithDescription("Number of tokens saved by filtering tools with optim_find_tool"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create tokens_saved counter", "error", err)
+		return
+	}
+
+	savingsGauge, err := meter.Float64Gauge(
+		"toolhive_vmcp_optimizer_savings_percentage",
+		metric.WithDescription("Percentage of tokens saved by filtering tools (0-100)"),
+		metric.WithUnit("%"),
+	)
+	if err != nil {
+		logger.Debugw("Failed to create savings_percentage gauge", "error", err)
+		return
+	}
+
+	attrs := metric.WithAttributes(
+		attribute.String("operation", "find_tool"),
+	)
+
+	baselineCounter.Add(ctx, int64(baselineTokens), attrs)
+	returnedCounter.Add(ctx, int64(returnedTokens), attrs)
+	savedCounter.Add(ctx, int64(tokensSaved), attrs)
+	savingsGauge.Record(ctx, savingsPercentage, attrs)
+}
+
+// Helper functions
+
+// extractFindToolParams extracts and validates parameters from the find_tool request.
+// The err return is a ready-to-send *mcp.CallToolResult error (nil on success),
+// not a Go error.
+func extractFindToolParams(args map[string]any) (toolDescription, toolKeywords string, limit int, err *mcp.CallToolResult) {
+	toolDescription, ok := args["tool_description"].(string)
+	if !ok || toolDescription == "" {
+		return "", "", 0, mcp.NewToolResultError("tool_description is required and must be a non-empty string")
+	}
+
+	// tool_keywords is optional; a missing or non-string value yields "".
+	toolKeywords, _ = args["tool_keywords"].(string)
+
+	// limit defaults to 10; JSON-decoded numbers arrive as float64, so only
+	// that type is accepted here.
+	limit = 10 // Default
+	if limitVal, ok := args["limit"]; ok {
+		if limitFloat, ok := limitVal.(float64); ok {
+			limit = int(limitFloat)
+		}
+	}
+
+	return toolDescription, toolKeywords, limit, nil
+}
+
+// parseCallToolRequest extracts and validates parameters from the call_tool request.
+// All three fields (backend_id, tool_name, parameters) are required; any
+// validation failure is returned as a Go error for the caller to convert
+// into a client-facing tool-result error.
+func parseCallToolRequest(request mcp.CallToolRequest) (backendID, toolName string, parameters map[string]any, err error) {
+	args, ok := request.Params.Arguments.(map[string]any)
+	if !ok {
+		return "", "", nil, fmt.Errorf("invalid arguments: expected object")
+	}
+
+	backendID, ok = args["backend_id"].(string)
+	if !ok || backendID == "" {
+		return "", "", nil, fmt.Errorf("backend_id is required and must be a non-empty string")
+	}
+
+	toolName, ok = args["tool_name"].(string)
+	if !ok || toolName == "" {
+		return "", "", nil, fmt.Errorf("tool_name is required and must be a non-empty string")
+	}
+
+	parameters, ok = args["parameters"].(map[string]any)
+	if !ok {
+		return "", "", nil, fmt.Errorf("parameters is required and must be an object")
+	}
+
+	return backendID, toolName, parameters, nil
+}
+
+// resolveToolName looks up the resolved name for a tool in the routing table.
+// It scans the table for an entry on the given backend matching originalName,
+// and falls back to originalName when the table is absent or has no match.
+func resolveToolName(routingTable *vmcp.RoutingTable, backendID string, originalName string) string {
+	if routingTable == nil || routingTable.Tools == nil {
+		return originalName
+	}
+
+	for resolvedName, target := range routingTable.Tools {
+		// Case 1: Tool was renamed
+		if target.WorkloadID == backendID && target.OriginalCapabilityName == originalName {
+			return resolvedName
+		}
+
+		// Case 2: Tool was not renamed
+		if target.WorkloadID == backendID && target.OriginalCapabilityName == "" && resolvedName == originalName {
+			return resolvedName
+		}
+	}
+
+	return originalName // Fallback
+}
+
+// convertToolResult converts vmcp.ToolCallResult to mcp.CallToolResult.
+// Each content item is converted individually; the IsError flag is carried over.
+func convertToolResult(result *vmcp.ToolCallResult) *mcp.CallToolResult {
+	mcpContent := make([]mcp.Content, len(result.Content))
+	for i, content := range result.Content {
+		mcpContent[i] = convertVMCPContent(content)
+	}
+
+	return &mcp.CallToolResult{
+		Content: mcpContent,
+		IsError: result.IsError,
+	}
+}
+
+// convertVMCPContent converts a vmcp.Content to mcp.Content.
+// Unsupported content types ("resource", or anything unrecognized) degrade to
+// an empty text block with a warning rather than erroring out.
+func convertVMCPContent(content vmcp.Content) mcp.Content {
+	switch content.Type {
+	case "text":
+		return mcp.NewTextContent(content.Text)
+	case "image":
+		return mcp.NewImageContent(content.Data, content.MimeType)
+	case "audio":
+		return mcp.NewAudioContent(content.Data, content.MimeType)
+	case "resource":
+		logger.Warnw("Converting resource content to text - embedded resources not yet supported")
+		return mcp.NewTextContent("")
+	default:
+		logger.Warnw("Converting unknown content type to text", "type", content.Type)
+		return mcp.NewTextContent("")
+	}
+}
+
+// OnRegisterSession is a test helper that registers a session without all the infrastructure setup.
+// It's a simplified version for testing purposes.
+func (o *EmbeddingOptimizer) OnRegisterSession(
+	_ context.Context,
+	_ interface{}, // session - not used in simplified test version
+	_ *aggregator.AggregatedCapabilities, // capabilities - not used in simplified test version
+) error {
+	// Test helper - no-op implementation; safe to call on a nil receiver.
+	if o == nil {
+		return nil
+	}
+	return nil
+}
+
+// RegisterTools is a test helper for registering optimizer tools with a session.
+// It's a simplified version for testing purposes.
+func (o *EmbeddingOptimizer) RegisterTools(
+	_ context.Context,
+	_ interface{}, // session - not used in simplified test version
+) error {
+	// Test helper - no-op implementation (or could panic if o is nil)
+	if o == nil {
+		return nil
+	}
+	return nil
+}
+
+// IngestToolsForTesting manually ingests tools for testing purposes.
+// This is a test helper that bypasses the normal ingestion flow.
+func (o *EmbeddingOptimizer) IngestToolsForTesting(
+	ctx context.Context,
+	serverID string,
+	serverName string,
+	description *string,
+	tools []mcp.Tool,
+) error {
+	// Requires a fully initialized optimizer; without an ingestion service
+	// there is nowhere to ingest into.
+	if o.ingestionService == nil {
+		return fmt.Errorf("optimizer integration not initialized")
+	}
+	return o.ingestionService.IngestServer(ctx, serverID, serverName, description, tools)
 }
diff --git a/pkg/vmcp/optimizer/optimizer_handlers_test.go b/pkg/vmcp/optimizer/optimizer_handlers_test.go
new file mode 100644
index 0000000000..5837993027
--- /dev/null
+++ b/pkg/vmcp/optimizer/optimizer_handlers_test.go
@@ -0,0 +1,1020 @@
+// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package optimizer
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"path/filepath"
+	"testing"
+	"time"
+
+	"github.com/mark3labs/mcp-go/mcp"
+	"github.com/mark3labs/mcp-go/server"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	transportsession "github.com/stacklok/toolhive/pkg/transport/session"
+	"github.com/stacklok/toolhive/pkg/vmcp"
+	"github.com/stacklok/toolhive/pkg/vmcp/aggregator"
+	"github.com/stacklok/toolhive/pkg/vmcp/discovery"
+	"github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings"
+	vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session"
+)
+
+// mockMCPServerWithSession implements AddSessionTools for testing
+type mockMCPServerWithSession struct {
+	*server.MCPServer
+	toolsAdded map[string][]server.ServerTool
+}
+
+func newMockMCPServerWithSession() *mockMCPServerWithSession {
+	return &mockMCPServerWithSession{
+		MCPServer:  server.NewMCPServer("test-server", "1.0"),
+		toolsAdded: make(map[string][]server.ServerTool),
+	}
+}
+
+func (m *mockMCPServerWithSession) AddSessionTools(sessionID string, tools ...server.ServerTool) error {
+	m.toolsAdded[sessionID] = tools
+	return nil
+}
+
+// mockBackendClientWithCallTool implements CallTool for testing
+type mockBackendClientWithCallTool struct {
+
callToolResult map[string]any + callToolError error +} + +func (*mockBackendClientWithCallTool) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + return &vmcp.CapabilityList{}, nil +} + +func (m *mockBackendClientWithCallTool) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + if m.callToolError != nil { + return nil, m.callToolError + } + // Convert map[string]any to ToolCallResult with JSON-marshaled content + jsonBytes, err := json.Marshal(m.callToolResult) + if err != nil { + return nil, fmt.Errorf("failed to marshal call tool result: %w", err) + } + result := &vmcp.ToolCallResult{ + Content: []vmcp.Content{ + { + Type: "text", + Text: string(jsonBytes), + }, + }, + StructuredContent: m.callToolResult, + } + return result, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return &vmcp.PromptGetResult{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClientWithCallTool) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) (*vmcp.ResourceReadResult, error) { + return &vmcp.ResourceReadResult{}, nil +} + +// TestCreateFindToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateFindToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Setup optimizer integration + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "limit": 10, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_description") + + // Test with empty tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty tool_description") + + // Test with non-string tool_description + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": 123, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, 
"Should return error for non-string tool_description") +} + +// TestCreateFindToolHandler_WithKeywords tests find_tool with keywords +func TestCreateFindToolHandler_WithKeywords(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest a tool for testing + tools := []mcp.Tool{ + { + Name: "test_tool", + Description: "A test tool for searching", + }, + } + + err = integration.IngestToolsForTesting(ctx, "server-1", "TestServer", nil, tools) + require.NoError(t, err) + + handler := integration.CreateFindToolHandler() + + // Test with keywords + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "search tool", + "tool_keywords": "test search", + "limit": 10, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response structure + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response 
map[string]any + err = json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + + _, ok = response["tools"] + require.True(t, ok, "Response should have tools") + + _, ok = response["token_metrics"] + require.True(t, ok, "Response should have token_metrics") +} + +// TestCreateFindToolHandler_Limit tests limit parameter handling +func TestCreateFindToolHandler_Limit(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + // Test with custom limit + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) + + // Test with float64 limit (from JSON) + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + "limit": float64(3), + }, + }, + } + + result, 
err = handler(ctx, request) + require.NoError(t, err) + require.False(t, result.IsError) +} + +// TestCreateFindToolHandler_BackendToolOpsNil tests error when backend tool ops is nil +func TestCreateFindToolHandler_BackendToolOpsNil(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create integration with nil ingestion service to trigger error path + integration := &OptimizerIntegration{ + config: &Config{Enabled: true}, + ingestionService: nil, // This will cause GetDatabase to return nil + } + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend tool ops is nil") +} + +// TestCreateCallToolHandler_InvalidArguments tests error handling for invalid arguments +func TestCreateCallToolHandler_InvalidArguments(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test with invalid arguments type + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: "not a map", + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid arguments") + + // Test with missing backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing backend_id") + + // Test with empty backend_id + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for empty backend_id") + + // Test with missing tool_name + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "parameters": map[string]any{}, + 
}, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing tool_name") + + // Test with missing parameters + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for missing parameters") + + // Test with invalid parameters type + request = mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": "not a map", + }, + }, + } + + result, err = handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error for invalid parameters type") +} + +// TestCreateCallToolHandler_NoRoutingTable tests error when routing table is missing +func TestCreateCallToolHandler_NoRoutingTable(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Test without routing table in context + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when routing table is missing") +} + +// TestCreateCallToolHandler_ToolNotFound tests error when tool is not found +func TestCreateCallToolHandler_ToolNotFound(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table but tool not found + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: make(map[string]*vmcp.BackendTarget), + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "nonexistent_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when tool is not found") +} + +// TestCreateCallToolHandler_BackendMismatch tests error when backend doesn't match +func TestCreateCallToolHandler_BackendMismatch(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not 
available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table where tool belongs to different backend + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-2", // Different backend + WorkloadName: "Backend 2", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", // Requesting backend-1 + "tool_name": "test_tool", // But tool belongs to backend-2 + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when backend doesn't match") +} + +// TestCreateCallToolHandler_Success tests successful tool call +func TestCreateCallToolHandler_Success(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: 
"all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolResult: map[string]any{ + "result": "success", + }, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + // Create context with routing table + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + "backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{ + "param1": "value1", + }, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.False(t, result.IsError, "Should not return error") + + // Verify response + textContent, ok := mcp.AsTextContent(result.Content[0]) + require.True(t, ok) + + var response map[string]any + err = 
json.Unmarshal([]byte(textContent.Text), &response) + require.NoError(t, err) + assert.Equal(t, "success", response["result"]) +} + +// TestCreateCallToolHandler_CallToolError tests error handling when CallTool fails +func TestCreateCallToolHandler_CallToolError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClientWithCallTool{ + callToolError: assert.AnError, + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateCallToolHandler() + + target := &vmcp.BackendTarget{ + WorkloadID: "backend-1", + WorkloadName: "Backend 1", + BaseURL: "http://localhost:8000", + } + + capabilities := &aggregator.AggregatedCapabilities{ + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": target, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + ctxWithCaps := discovery.WithDiscoveredCapabilities(ctx, capabilities) + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_call_tool", + Arguments: map[string]any{ + 
"backend_id": "backend-1", + "tool_name": "test_tool", + "parameters": map[string]any{}, + }, + }, + } + + result, err := handler(ctxWithCaps, request) + require.NoError(t, err) + require.True(t, result.IsError, "Should return error when CallTool fails") +} + +// TestCreateFindToolHandler_InputSchemaUnmarshalError tests error handling for invalid input schema +func TestCreateFindToolHandler_InputSchemaUnmarshalError(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + handler := integration.CreateFindToolHandler() + + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "test", + }, + }, + } + + // The handler should handle invalid input schema gracefully + result, err := handler(ctx, request) + require.NoError(t, err) + // Should not error even if some tools have invalid schemas + require.False(t, result.IsError) +} + +// TestOnRegisterSession_DuplicateSession tests duplicate session handling +func TestOnRegisterSession_DuplicateSession(t 
*testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + // First call + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Second call with same session ID (should be skipped) + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err, "Should handle duplicate session gracefully") +} + +// TestIngestInitialBackends_ErrorHandling tests error handling during ingestion +func TestIngestInitialBackends_ErrorHandling(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Check Ollama availability first + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v", err) + return + } + _ = embeddingManager.Close() + + mcpServer := newMockMCPServerWithSession() + mockClient := &mockBackendClient{ + err: assert.AnError, // Simulate error when listing capabilities + } + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "all-minilm", + EmbeddingDimension: 384, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer.MCPServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + backends := []vmcp.Backend{ + { + ID: "backend-1", + Name: "Backend 1", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // Should not fail even if backend query fails + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle backend query errors gracefully") +} + +// TestIngestInitialBackends_NilIntegration tests nil integration handling +func TestIngestInitialBackends_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + backends := []vmcp.Backend{} + + err := integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} diff --git a/pkg/vmcp/optimizer/optimizer_integration_test.go b/pkg/vmcp/optimizer/optimizer_integration_test.go new file mode 100644 index 0000000000..39a090b5c1 --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_integration_test.go @@ -0,0 +1,433 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/require" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockBackendClient implements vmcp.BackendClient for integration testing +type mockIntegrationBackendClient struct { + backends map[string]*vmcp.CapabilityList +} + +func newMockIntegrationBackendClient() *mockIntegrationBackendClient { + return &mockIntegrationBackendClient{ + backends: make(map[string]*vmcp.CapabilityList), + } +} + +func (m *mockIntegrationBackendClient) addBackend(backendID string, caps *vmcp.CapabilityList) { + m.backends[backendID] = caps +} + +func (m *mockIntegrationBackendClient) ListCapabilities(_ context.Context, target *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if caps, exists := m.backends[target.WorkloadID]; exists { + return caps, nil + } + return &vmcp.CapabilityList{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + return &vmcp.ToolCallResult{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return &vmcp.PromptGetResult{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) 
(*vmcp.ResourceReadResult, error) { + return &vmcp.ResourceReadResult{}, nil +} + +// mockIntegrationSession implements server.ClientSession for testing +type mockIntegrationSession struct { + sessionID string +} + +func (m *mockIntegrationSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockIntegrationSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestOptimizerIntegration_WithVMCP tests the complete integration with vMCP +func TestOptimizerIntegration_WithVMCP(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. 
Error: %v. Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Simulate session registration + session := &mockIntegrationSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + BackendID: "github", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "create_issue": { + WorkloadID: "github", + WorkloadName: "GitHub", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + require.NoError(t, err) + + // Note: We don't test RegisterTools here because it requires the session + // to be properly registered with the MCP server, which is beyond the scope + // of this integration test. The RegisterTools method is tested separately + // in unit tests where we can properly mock the MCP server behavior. 
+} + +// TestOptimizerIntegration_EmbeddingTimeTracking tests that embedding time is tracked and logged +func TestOptimizerIntegration_EmbeddingTimeTracking(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_repo", + Description: "Get repository information", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Verify embedding time starts at 0 + embeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Equal(t, time.Duration(0), embeddingTime, "Initial embedding time should be 0") + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // After ingestion, embedding time should be tracked + // Note: The actual time depends on Ollama performance, but it should be > 0 + finalEmbeddingTime := integration.ingestionService.GetTotalEmbeddingTime() + require.Greater(t, finalEmbeddingTime, time.Duration(0), + "Embedding time should be tracked after ingestion") +} + +// TestOptimizerIntegration_DisabledEmbeddingTime tests that embedding time is 0 when optimizer is disabled +func TestOptimizerIntegration_DisabledEmbeddingTime(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Create optimizer integration with disabled optimizer + optimizerConfig := &Config{ + Enabled: false, + } + + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + mockClient := newMockIntegrationBackendClient() + sessionMgr := 
transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.Nil(t, integration, "Integration should be nil when optimizer is disabled") + + // Try to ingest backends - should return nil without error + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + // This should handle nil integration gracefully + var nilIntegration *OptimizerIntegration + err = nilIntegration.IngestInitialBackends(ctx, backends) + require.NoError(t, err, "Should handle nil integration gracefully") +} + +// TestOptimizerIntegration_TokenMetrics tests that token metrics are calculated and returned in optim_find_tool +func TestOptimizerIntegration_TokenMetrics(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Create MCP server + mcpServer := server.NewMCPServer("vmcp-test", "1.0") + + // Create mock backend client with multiple tools + mockClient := newMockIntegrationBackendClient() + mockClient.addBackend("github", &vmcp.CapabilityList{ + Tools: []vmcp.Tool{ + { + Name: "create_issue", + Description: "Create a GitHub issue", + }, + { + Name: "get_pull_request", + Description: "Get a pull request from GitHub", + }, + { + Name: "list_repositories", + Description: "List repositories from GitHub", + }, + }, + }) + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: embeddings.BackendTypeOllama, + BaseURL: "http://localhost:11434", + Model: embeddings.DefaultModelAllMiniLM, + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull %s'", err, embeddings.DefaultModelAllMiniLM) + return + } + t.Cleanup(func() { _ = embeddingManager.Close() }) + + // Configure optimizer + optimizerConfig := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: embeddings.BackendTypeOllama, + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: embeddings.DefaultModelAllMiniLM, + EmbeddingDimension: 384, + } + + // Create optimizer integration + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, optimizerConfig, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + // Ingest backends + backends := []vmcp.Backend{ + { + ID: "github", + Name: "GitHub", + BaseURL: "http://localhost:8000", + TransportType: "sse", + }, + } + + err = integration.IngestInitialBackends(ctx, backends) + require.NoError(t, err) + + // Get the find_tool handler + handler := integration.CreateFindToolHandler() + require.NotNil(t, handler) + + // Call optim_find_tool + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "optim_find_tool", + Arguments: map[string]any{ + "tool_description": "create issue", + "limit": 5, + }, + }, + } + + result, err := handler(ctx, request) + require.NoError(t, err) + require.NotNil(t, result) + + // Verify result contains token_metrics + require.NotNil(t, result.Content) + require.Len(t, result.Content, 1) + textResult, ok := result.Content[0].(mcp.TextContent) + require.True(t, ok, "Result should be TextContent") + + // Parse JSON response + var response map[string]any + err = json.Unmarshal([]byte(textResult.Text), &response) + require.NoError(t, err) + + // Verify token_metrics exist + tokenMetrics, ok := response["token_metrics"].(map[string]any) + require.True(t, ok, "Response should contain token_metrics") + + // Verify token metrics fields + baselineTokens, ok := 
tokenMetrics["baseline_tokens"].(float64) + require.True(t, ok, "token_metrics should contain baseline_tokens") + require.Greater(t, baselineTokens, float64(0), "baseline_tokens should be greater than 0") + + returnedTokens, ok := tokenMetrics["returned_tokens"].(float64) + require.True(t, ok, "token_metrics should contain returned_tokens") + require.GreaterOrEqual(t, returnedTokens, float64(0), "returned_tokens should be >= 0") + + tokensSaved, ok := tokenMetrics["tokens_saved"].(float64) + require.True(t, ok, "token_metrics should contain tokens_saved") + require.GreaterOrEqual(t, tokensSaved, float64(0), "tokens_saved should be >= 0") + + savingsPercentage, ok := tokenMetrics["savings_percentage"].(float64) + require.True(t, ok, "token_metrics should contain savings_percentage") + require.GreaterOrEqual(t, savingsPercentage, float64(0), "savings_percentage should be >= 0") + require.LessOrEqual(t, savingsPercentage, float64(100), "savings_percentage should be <= 100") + + // Verify tools are returned + tools, ok := response["tools"].([]any) + require.True(t, ok, "Response should contain tools") + require.Greater(t, len(tools), 0, "Should return at least one tool") +} diff --git a/pkg/vmcp/optimizer/optimizer_unit_test.go b/pkg/vmcp/optimizer/optimizer_unit_test.go new file mode 100644 index 0000000000..f1dd90128d --- /dev/null +++ b/pkg/vmcp/optimizer/optimizer_unit_test.go @@ -0,0 +1,330 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package optimizer + +import ( + "context" + "path/filepath" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + transportsession "github.com/stacklok/toolhive/pkg/transport/session" + "github.com/stacklok/toolhive/pkg/vmcp" + "github.com/stacklok/toolhive/pkg/vmcp/aggregator" + "github.com/stacklok/toolhive/pkg/vmcp/optimizer/internal/embeddings" + vmcpsession "github.com/stacklok/toolhive/pkg/vmcp/session" +) + +// mockBackendClient implements vmcp.BackendClient for testing +type mockBackendClient struct { + capabilities *vmcp.CapabilityList + err error +} + +func (m *mockBackendClient) ListCapabilities(_ context.Context, _ *vmcp.BackendTarget) (*vmcp.CapabilityList, error) { + if m.err != nil { + return nil, m.err + } + return m.capabilities, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) CallTool(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any, _ map[string]any) (*vmcp.ToolCallResult, error) { + return &vmcp.ToolCallResult{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) GetPrompt(_ context.Context, _ *vmcp.BackendTarget, _ string, _ map[string]any) (*vmcp.PromptGetResult, error) { + return &vmcp.PromptGetResult{}, nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockBackendClient) ReadResource(_ context.Context, _ *vmcp.BackendTarget, _ string) (*vmcp.ResourceReadResult, error) { + return &vmcp.ResourceReadResult{}, nil +} + +// mockSession implements server.ClientSession for testing +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Send(_ interface{}) error { + return nil +} + +//nolint:revive 
// Receiver unused in mock implementation +func (m *mockSession) Close() error { + return nil +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialize() { + // No-op for testing +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) Initialized() bool { + return true +} + +//nolint:revive // Receiver unused in mock implementation +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + // Return a dummy channel for testing + ch := make(chan mcp.JSONRPCNotification, 1) + return ch +} + +// TestNewIntegration_Disabled tests that nil is returned when optimizer is disabled +func TestNewIntegration_Disabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + + // Test with nil config + integration, err := NewIntegration(ctx, nil, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when config is nil") + + // Test with disabled config + config := &Config{Enabled: false} + integration, err = NewIntegration(ctx, config, nil, nil, nil) + require.NoError(t, err) + assert.Nil(t, integration, "Should return nil when optimizer is disabled") +} + +// TestNewIntegration_Enabled tests successful creation +func TestNewIntegration_Enabled(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 768, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + require.NotNil(t, integration) + defer func() { _ = integration.Close() }() +} + +// TestOnRegisterSession tests session registration +func TestOnRegisterSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. 
Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 768, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{ + Tools: []vmcp.Tool{ + { + Name: "test_tool", + Description: "A test tool", + BackendID: "backend-1", + }, + }, + RoutingTable: &vmcp.RoutingTable{ + Tools: map[string]*vmcp.BackendTarget{ + "test_tool": { + WorkloadID: "backend-1", + WorkloadName: "Test Backend", + }, + }, + Resources: map[string]*vmcp.BackendTarget{}, + Prompts: map[string]*vmcp.BackendTarget{}, + }, + } + + err = integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestOnRegisterSession_NilIntegration tests nil integration handling +func TestOnRegisterSession_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + capabilities := &aggregator.AggregatedCapabilities{} + + err := integration.OnRegisterSession(ctx, session, capabilities) + assert.NoError(t, err) +} + +// TestRegisterTools tests tool registration behavior +func TestRegisterTools(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: 
"http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 768, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + defer func() { _ = integration.Close() }() + + session := &mockSession{sessionID: "test-session"} + // RegisterTools will fail with "session not found" because the mock session + // is not actually registered with the MCP server. This is expected behavior. + // We're just testing that the method executes without panicking. 
+ _ = integration.RegisterTools(ctx, session) +} + +// TestRegisterTools_NilIntegration tests nil integration handling +func TestRegisterTools_NilIntegration(t *testing.T) { + t.Parallel() + ctx := context.Background() + + var integration *OptimizerIntegration = nil + session := &mockSession{sessionID: "test-session"} + + err := integration.RegisterTools(ctx, session) + assert.NoError(t, err) +} + +// TestClose tests cleanup +func TestClose(t *testing.T) { + t.Parallel() + ctx := context.Background() + tmpDir := t.TempDir() + + mcpServer := server.NewMCPServer("test-server", "1.0") + mockClient := &mockBackendClient{} + + // Try to use Ollama if available, otherwise skip test + embeddingConfig := &embeddings.Config{ + BackendType: "ollama", + BaseURL: "http://localhost:11434", + Model: "all-minilm", + Dimension: 384, + } + + embeddingManager, err := embeddings.NewManager(embeddingConfig) + if err != nil { + t.Skipf("Skipping test: Ollama not available. Error: %v. Run 'ollama serve && ollama pull all-minilm'", err) + return + } + _ = embeddingManager.Close() + + config := &Config{ + Enabled: true, + PersistPath: filepath.Join(tmpDir, "optimizer-db"), + EmbeddingBackend: "ollama", + EmbeddingURL: "http://localhost:11434", + EmbeddingModel: "nomic-embed-text", + EmbeddingDimension: 768, + } + + sessionMgr := transportsession.NewManager(30*time.Minute, vmcpsession.VMCPSessionFactory()) + integration, err := NewIntegration(ctx, config, mcpServer, mockClient, sessionMgr) + require.NoError(t, err) + + err = integration.Close() + assert.NoError(t, err) + + // Multiple closes should be safe + err = integration.Close() + assert.NoError(t, err) +} + +// TestClose_NilIntegration tests nil integration close +func TestClose_NilIntegration(t *testing.T) { + t.Parallel() + + var integration *OptimizerIntegration = nil + err := integration.Close() + assert.NoError(t, err) +} diff --git a/pkg/vmcp/schema/reflect_test.go b/pkg/vmcp/schema/reflect_test.go index 55d9491019..2e0da8ed28 
100644 --- a/pkg/vmcp/schema/reflect_test.go +++ b/pkg/vmcp/schema/reflect_test.go @@ -8,10 +8,85 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" ) +// FindToolInput represents the input schema for optim_find_tool +// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go +type FindToolInput struct { + ToolDescription string `json:"tool_description" description:"Natural language description of the tool you're looking for"` + ToolKeywords string `json:"tool_keywords,omitempty" description:"Optional space-separated keywords for keyword-based search"` + Limit int `json:"limit,omitempty" description:"Maximum number of tools to return (default: 10)"` +} + +// CallToolInput represents the input schema for optim_call_tool +// This matches the schema defined in pkg/vmcp/optimizer/optimizer.go +type CallToolInput struct { + BackendID string `json:"backend_id" description:"Backend ID from find_tool results"` + ToolName string `json:"tool_name" description:"Tool name to invoke"` + Parameters map[string]any `json:"parameters" description:"Parameters to pass to the tool"` +} + +func TestGenerateSchema_AllTypes(t *testing.T) { + t.Parallel() + + type TestStruct struct { + StringField string `json:"string_field,omitempty"` + IntField int `json:"int_field"` + FloatField float64 `json:"float_field,omitempty"` + BoolField bool `json:"bool_field"` + OptionalStr string `json:"optional_str,omitempty"` + SliceField []int `json:"slice_field"` + MapField map[string]string `json:"map_field"` + StructField struct { + RequiredField string `json:"field"` + OptionalField string `json:"optional_field,omitempty"` + } `json:"struct_field"` + PointerField *int `json:"pointer_field"` + } + + expected := map[string]any{ + "type": "object", + "properties": map[string]any{ + "string_field": map[string]any{"type": "string"}, + "int_field": map[string]any{"type": "integer"}, + "float_field": 
map[string]any{"type": "number"}, + "bool_field": map[string]any{"type": "boolean"}, + "optional_str": map[string]any{"type": "string"}, + "slice_field": map[string]any{ + "type": "array", + "items": map[string]any{"type": "integer"}, + }, + "map_field": map[string]any{"type": "object"}, + "struct_field": map[string]any{ + "type": "object", + "properties": map[string]any{ + "field": map[string]any{"type": "string"}, + "optional_field": map[string]any{"type": "string"}, + }, + "required": []string{"field"}, + }, + "pointer_field": map[string]any{ + "type": "integer", + }, + }, + "required": []string{ + "int_field", + "bool_field", + "map_field", + "struct_field", + "pointer_field", + "slice_field", + }, + } + + actual, err := GenerateSchema[TestStruct]() + require.NoError(t, err) + + require.Equal(t, expected["type"], actual["type"]) + require.Equal(t, expected["properties"], actual["properties"]) + require.ElementsMatch(t, expected["required"], actual["required"]) +} + func TestGenerateSchema_FindToolInput(t *testing.T) { t.Parallel() @@ -20,18 +95,21 @@ func TestGenerateSchema_FindToolInput(t *testing.T) { "properties": map[string]any{ "tool_description": map[string]any{ "type": "string", - "description": "Natural language description of the tool to find", + "description": "Natural language description of the tool you're looking for", }, "tool_keywords": map[string]any{ - "type": "array", - "items": map[string]any{"type": "string"}, - "description": "Optional keywords to narrow search", + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", }, }, "required": []string{"tool_description"}, } - actual, err := GenerateSchema[optimizer.FindToolInput]() + actual, err := GenerateSchema[FindToolInput]() require.NoError(t, err) require.Equal(t, expected, actual) @@ -43,19 +121,23 @@ func 
TestGenerateSchema_CallToolInput(t *testing.T) { expected := map[string]any{ "type": "object", "properties": map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, "tool_name": map[string]any{ "type": "string", - "description": "Name of the tool to call", + "description": "Tool name to invoke", }, "parameters": map[string]any{ "type": "object", "description": "Parameters to pass to the tool", }, }, - "required": []string{"tool_name", "parameters"}, + "required": []string{"backend_id", "tool_name", "parameters"}, } - actual, err := GenerateSchema[optimizer.CallToolInput]() + actual, err := GenerateSchema[CallToolInput]() require.NoError(t, err) require.Equal(t, expected, actual) @@ -66,15 +148,17 @@ func TestTranslate_FindToolInput(t *testing.T) { input := map[string]any{ "tool_description": "find a tool to read files", - "tool_keywords": []any{"file", "read"}, + "tool_keywords": "file read", + "limit": 5, } - result, err := Translate[optimizer.FindToolInput](input) + result, err := Translate[FindToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.FindToolInput{ + require.Equal(t, FindToolInput{ ToolDescription: "find a tool to read files", - ToolKeywords: []string{"file", "read"}, + ToolKeywords: "file read", + Limit: 5, }, result) } @@ -82,16 +166,18 @@ func TestTranslate_CallToolInput(t *testing.T) { t.Parallel() input := map[string]any{ - "tool_name": "read_file", + "backend_id": "backend-123", + "tool_name": "read_file", "parameters": map[string]any{ "path": "/etc/hosts", }, } - result, err := Translate[optimizer.CallToolInput](input) + result, err := Translate[CallToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.CallToolInput{ + require.Equal(t, CallToolInput{ + BackendID: "backend-123", ToolName: "read_file", Parameters: map[string]any{"path": "/etc/hosts"}, }, result) @@ -104,12 +190,13 @@ func TestTranslate_PartialInput(t *testing.T) { 
"tool_description": "find a file reader", } - result, err := Translate[optimizer.FindToolInput](input) + result, err := Translate[FindToolInput](input) require.NoError(t, err) - require.Equal(t, optimizer.FindToolInput{ + require.Equal(t, FindToolInput{ ToolDescription: "find a file reader", - ToolKeywords: nil, + ToolKeywords: "", + Limit: 0, }, result) } @@ -118,68 +205,7 @@ func TestTranslate_InvalidInput(t *testing.T) { input := make(chan int) - _, err := Translate[optimizer.FindToolInput](input) + _, err := Translate[FindToolInput](input) require.Error(t, err) assert.Contains(t, err.Error(), "failed to marshal input") } - -func TestGenerateSchema_AllTypes(t *testing.T) { - t.Parallel() - - type TestStruct struct { - StringField string `json:"string_field,omitempty"` - IntField int `json:"int_field"` - FloatField float64 `json:"float_field,omitempty"` - BoolField bool `json:"bool_field"` - OptionalStr string `json:"optional_str,omitempty"` - SliceField []int `json:"slice_field"` - MapField map[string]string `json:"map_field"` - StructField struct { - RequiredField string `json:"field"` - OptionalField string `json:"optional_field,omitempty"` - } `json:"struct_field"` - PointerField *int `json:"pointer_field"` - } - - expected := map[string]any{ - "type": "object", - "properties": map[string]any{ - "string_field": map[string]any{"type": "string"}, - "int_field": map[string]any{"type": "integer"}, - "float_field": map[string]any{"type": "number"}, - "bool_field": map[string]any{"type": "boolean"}, - "optional_str": map[string]any{"type": "string"}, - "slice_field": map[string]any{ - "type": "array", - "items": map[string]any{"type": "integer"}, - }, - "map_field": map[string]any{"type": "object"}, - "struct_field": map[string]any{ - "type": "object", - "properties": map[string]any{ - "field": map[string]any{"type": "string"}, - "optional_field": map[string]any{"type": "string"}, - }, - "required": []string{"field"}, - }, - "pointer_field": map[string]any{ - 
"type": "integer", - }, - }, - "required": []string{ - "int_field", - "bool_field", - "map_field", - "struct_field", - "pointer_field", - "slice_field", - }, - } - - actual, err := GenerateSchema[TestStruct]() - require.NoError(t, err) - - require.Equal(t, expected["type"], actual["type"]) - require.Equal(t, expected["properties"], actual["properties"]) - require.ElementsMatch(t, expected["required"], actual["required"]) -} diff --git a/pkg/vmcp/server/adapter/handler_factory.go b/pkg/vmcp/server/adapter/handler_factory.go index a836ef61a1..7f3cb51148 100644 --- a/pkg/vmcp/server/adapter/handler_factory.go +++ b/pkg/vmcp/server/adapter/handler_factory.go @@ -58,6 +58,17 @@ type WorkflowResult struct { Error error } +// OptimizerHandlerProvider provides handlers for optimizer tools. +// This interface allows the adapter to create optimizer tools without +// depending on the optimizer package implementation. +type OptimizerHandlerProvider interface { + // CreateFindToolHandler returns the handler for find_tool + CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) + + // CreateCallToolHandler returns the handler for call_tool + CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) +} + // DefaultHandlerFactory creates MCP request handlers that route to backend workloads. 
type DefaultHandlerFactory struct { router router.Router diff --git a/pkg/vmcp/server/adapter/optimizer_adapter.go b/pkg/vmcp/server/adapter/optimizer_adapter.go index 07a6f4cb72..d909024b1b 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter.go @@ -4,15 +4,11 @@ package adapter import ( - "context" "encoding/json" "fmt" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" - "github.com/stacklok/toolhive/pkg/vmcp/schema" ) // OptimizerToolNames defines the tool names exposed when optimizer is enabled. @@ -24,80 +20,102 @@ const ( // Pre-generated schemas for optimizer tools. // Generated at package init time so any schema errors panic at startup. var ( - findToolInputSchema = mustGenerateSchema[optimizer.FindToolInput]() - callToolInputSchema = mustGenerateSchema[optimizer.CallToolInput]() + findToolInputSchema = mustMarshalSchema(findToolSchema) + callToolInputSchema = mustMarshalSchema(callToolSchema) +) + +// Tool schemas defined once to eliminate duplication. 
+var ( + findToolSchema = mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "tool_description": map[string]any{ + "type": "string", + "description": "Natural language description of the tool you're looking for", + }, + "tool_keywords": map[string]any{ + "type": "string", + "description": "Optional space-separated keywords for keyword-based search", + }, + "limit": map[string]any{ + "type": "integer", + "description": "Maximum number of tools to return (default: 10)", + "default": 10, + }, + }, + Required: []string{"tool_description"}, + } + + callToolSchema = mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "backend_id": map[string]any{ + "type": "string", + "description": "Backend ID from find_tool results", + }, + "tool_name": map[string]any{ + "type": "string", + "description": "Tool name to invoke", + }, + "parameters": map[string]any{ + "type": "object", + "description": "Parameters to pass to the tool", + }, + }, + Required: []string{"backend_id", "tool_name", "parameters"}, + } ) // CreateOptimizerTools creates the SDK tools for optimizer mode. // When optimizer is enabled, only these two tools are exposed to clients // instead of all backend tools. -func CreateOptimizerTools(opt optimizer.Optimizer) []server.ServerTool { +// +// This function uses the OptimizerHandlerProvider interface to get handlers, +// allowing it to work with any optimizer implementation without direct dependency. +// +// Deprecated: Use CreateOptimizerToolsFromProvider for new code. +// This signature is kept for backward compatibility with existing server code. +func CreateOptimizerTools(provider OptimizerHandlerProvider) []server.ServerTool { + if provider == nil { + // Return empty slice for nil provider (backward compat with old behavior) + return nil + } + return []server.ServerTool{ { Tool: mcp.Tool{ Name: FindToolName, - Description: "Search for tools by description. 
Returns matching tools ranked by relevance.", + Description: "Semantic search across all backend tools using natural language description and optional keywords", RawInputSchema: findToolInputSchema, }, - Handler: createFindToolHandler(opt), + Handler: provider.CreateFindToolHandler(), }, { Tool: mcp.Tool{ Name: CallToolName, - Description: "Call a tool by name with the given parameters.", + Description: "Dynamically invoke any tool on any backend using the backend_id from find_tool", RawInputSchema: callToolInputSchema, }, - Handler: createCallToolHandler(opt), + Handler: provider.CreateCallToolHandler(), }, } } -// createFindToolHandler creates a handler for the find_tool optimizer operation. -func createFindToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.FindToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - output, err := opt.FindTool(ctx, input) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("find_tool failed: %v", err)), nil - } - - return mcp.NewToolResultStructuredOnly(output), nil +// CreateOptimizerToolsFromProvider creates optimizer tools with error handling. +// This is the preferred function for new code. +func CreateOptimizerToolsFromProvider(provider OptimizerHandlerProvider) ([]server.ServerTool, error) { + if provider == nil { + return nil, fmt.Errorf("optimizer handler provider cannot be nil") } -} -// createCallToolHandler creates a handler for the call_tool optimizer operation. 
-func createCallToolHandler(opt optimizer.Optimizer) func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { - return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - input, err := schema.Translate[optimizer.CallToolInput](request.Params.Arguments) - if err != nil { - return mcp.NewToolResultError(fmt.Sprintf("invalid arguments: %v", err)), nil - } - - result, err := opt.CallTool(ctx, input) - if err != nil { - // Exposing the error to the MCP client is important if you want it to correct its behavior. - // Without information on the failure, the model is pretty much hopeless in figuring out the problem. - return mcp.NewToolResultError(fmt.Sprintf("call_tool failed: %v", err)), nil - } - - return result, nil - } + return CreateOptimizerTools(provider), nil } // mustMarshalSchema marshals a schema to JSON, panicking on error. // This is safe because schemas are generated from known types at startup. // This should NOT be called by runtime code. -func mustGenerateSchema[T any]() json.RawMessage { - s, err := schema.GenerateSchema[T]() - if err != nil { - panic(fmt.Sprintf("failed to generate schema: %v", err)) - } - - data, err := json.Marshal(s) +func mustMarshalSchema(schema mcp.ToolInputSchema) json.RawMessage { + data, err := json.Marshal(schema) if err != nil { panic(fmt.Sprintf("failed to marshal schema: %v", err)) } diff --git a/pkg/vmcp/server/adapter/optimizer_adapter_test.go b/pkg/vmcp/server/adapter/optimizer_adapter_test.go index b5ad7e066a..35da65f425 100644 --- a/pkg/vmcp/server/adapter/optimizer_adapter_test.go +++ b/pkg/vmcp/server/adapter/optimizer_adapter_test.go @@ -9,65 +9,94 @@ import ( "github.com/mark3labs/mcp-go/mcp" "github.com/stretchr/testify/require" - - "github.com/stacklok/toolhive/pkg/vmcp/optimizer" ) -// mockOptimizer implements optimizer.Optimizer for testing. 
-type mockOptimizer struct { - findToolFunc func(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) - callToolFunc func(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) +// mockOptimizerHandlerProvider implements OptimizerHandlerProvider for testing. +type mockOptimizerHandlerProvider struct { + findToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) + callToolHandler func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) } -func (m *mockOptimizer) FindTool(ctx context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - if m.findToolFunc != nil { - return m.findToolFunc(ctx, input) +func (m *mockOptimizerHandlerProvider) CreateFindToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if m.findToolHandler != nil { + return m.findToolHandler + } + return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("ok"), nil } - return &optimizer.FindToolOutput{}, nil } -func (m *mockOptimizer) CallTool(ctx context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - if m.callToolFunc != nil { - return m.callToolFunc(ctx, input) +func (m *mockOptimizerHandlerProvider) CreateCallToolHandler() func(context.Context, mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if m.callToolHandler != nil { + return m.callToolHandler + } + return func(_ context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("ok"), nil } - return mcp.NewToolResultText("ok"), nil } func TestCreateOptimizerTools(t *testing.T) { t.Parallel() - opt := &mockOptimizer{} - tools := CreateOptimizerTools(opt) + provider := &mockOptimizerHandlerProvider{} + tools := CreateOptimizerTools(provider) require.Len(t, tools, 2) require.Equal(t, FindToolName, tools[0].Tool.Name) require.Equal(t, CallToolName, 
tools[1].Tool.Name) } +func TestCreateOptimizerTools_NilProvider(t *testing.T) { + t.Parallel() + + tools := CreateOptimizerTools(nil) + + require.Nil(t, tools) +} + +func TestCreateOptimizerToolsFromProvider(t *testing.T) { + t.Parallel() + + provider := &mockOptimizerHandlerProvider{} + tools, err := CreateOptimizerToolsFromProvider(provider) + + require.NoError(t, err) + require.Len(t, tools, 2) + require.Equal(t, FindToolName, tools[0].Tool.Name) + require.Equal(t, CallToolName, tools[1].Tool.Name) +} + +func TestCreateOptimizerToolsFromProvider_NilProvider(t *testing.T) { + t.Parallel() + + tools, err := CreateOptimizerToolsFromProvider(nil) + + require.Error(t, err) + require.Nil(t, tools) + require.Contains(t, err.Error(), "cannot be nil") +} + func TestFindToolHandler(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - findToolFunc: func(_ context.Context, input optimizer.FindToolInput) (*optimizer.FindToolOutput, error) { - require.Equal(t, "read files", input.ToolDescription) - return &optimizer.FindToolOutput{ - Tools: []optimizer.ToolMatch{ - { - Name: "read_file", - Description: "Read a file", - Score: 1.0, - }, - }, - }, nil + provider := &mockOptimizerHandlerProvider{ + findToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, ok := req.Params.Arguments.(map[string]any) + require.True(t, ok) + require.Equal(t, "read files", args["tool_description"]) + return mcp.NewToolResultText("found tools"), nil }, } - tools := CreateOptimizerTools(opt) + tools := CreateOptimizerTools(provider) handler := tools[0].Handler - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_description": "read files", + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "tool_description": "read files", + }, + }, } result, err := handler(context.Background(), request) @@ -80,22 +109,28 @@ func TestFindToolHandler(t *testing.T) { func 
TestCallToolHandler(t *testing.T) { t.Parallel() - opt := &mockOptimizer{ - callToolFunc: func(_ context.Context, input optimizer.CallToolInput) (*mcp.CallToolResult, error) { - require.Equal(t, "read_file", input.ToolName) - require.Equal(t, "/etc/hosts", input.Parameters["path"]) + provider := &mockOptimizerHandlerProvider{ + callToolHandler: func(_ context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + args, ok := req.Params.Arguments.(map[string]any) + require.True(t, ok) + require.Equal(t, "read_file", args["tool_name"]) + params := args["parameters"].(map[string]any) + require.Equal(t, "/etc/hosts", params["path"]) return mcp.NewToolResultText("file contents here"), nil }, } - tools := CreateOptimizerTools(opt) + tools := CreateOptimizerTools(provider) handler := tools[1].Handler - request := mcp.CallToolRequest{} - request.Params.Arguments = map[string]any{ - "tool_name": "read_file", - "parameters": map[string]any{ - "path": "/etc/hosts", + request := mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Arguments: map[string]any{ + "tool_name": "read_file", + "parameters": map[string]any{ + "path": "/etc/hosts", + }, + }, }, }