From 7d191e664b2ad46747fc97961e33b81d47f17764 Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Tue, 17 Mar 2026 13:22:42 +0200 Subject: [PATCH] Add skills methods to registry Provider interface PR #4173 added SkillsClient for querying the registry's skills extension API. This wires that client into the registry provider layer so callers can discover skills through the same Provider interface used for servers. - Extend Provider interface with GetSkill, ListSkills, SearchSkills - Add default empty implementations on BaseProvider (inherited by Local/Remote providers that don't serve skills) - Add real implementations on APIRegistryProvider that delegate to the SkillsClient with appropriate timeouts - CachedAPIRegistryProvider inherits skills methods via embedding (uncached pass-through for now) - Regenerate mock provider Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/registry/mocks/mock_provider.go | 46 ++++ pkg/registry/provider.go | 15 +- pkg/registry/provider_api.go | 29 +++ pkg/registry/provider_base.go | 16 ++ pkg/registry/provider_cached.go | 2 + pkg/registry/skills_provider_test.go | 339 +++++++++++++++++++++++++++ 6 files changed, 446 insertions(+), 1 deletion(-) create mode 100644 pkg/registry/skills_provider_test.go diff --git a/pkg/registry/mocks/mock_provider.go b/pkg/registry/mocks/mock_provider.go index 883c9f16bd..b206f6ae4d 100644 --- a/pkg/registry/mocks/mock_provider.go +++ b/pkg/registry/mocks/mock_provider.go @@ -13,6 +13,7 @@ import ( reflect "reflect" registry "github.com/stacklok/toolhive-core/registry/types" + api "github.com/stacklok/toolhive/pkg/registry/api" gomock "go.uber.org/mock/gomock" ) @@ -85,6 +86,21 @@ func (mr *MockProviderMockRecorder) GetServer(name any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServer", reflect.TypeOf((*MockProvider)(nil).GetServer), name) } +// GetSkill mocks base method. +func (m *MockProvider) GetSkill(namespace, name string) (*registry.Skill, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSkill", namespace, name) + ret0, _ := ret[0].(*registry.Skill) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSkill indicates an expected call of GetSkill. +func (mr *MockProviderMockRecorder) GetSkill(namespace, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSkill", reflect.TypeOf((*MockProvider)(nil).GetSkill), namespace, name) +} + // ListImageServers mocks base method. func (m *MockProvider) ListImageServers() ([]*registry.ImageMetadata, error) { m.ctrl.T.Helper() @@ -115,6 +131,21 @@ func (mr *MockProviderMockRecorder) ListServers() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServers", reflect.TypeOf((*MockProvider)(nil).ListServers)) } +// ListSkills mocks base method. +func (m *MockProvider) ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSkills", opts) + ret0, _ := ret[0].(*api.SkillsListResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSkills indicates an expected call of ListSkills. +func (mr *MockProviderMockRecorder) ListSkills(opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSkills", reflect.TypeOf((*MockProvider)(nil).ListSkills), opts) +} + // SearchImageServers mocks base method. func (m *MockProvider) SearchImageServers(query string) ([]*registry.ImageMetadata, error) { m.ctrl.T.Helper() @@ -144,3 +175,18 @@ func (mr *MockProviderMockRecorder) SearchServers(query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchServers", reflect.TypeOf((*MockProvider)(nil).SearchServers), query) } + +// SearchSkills mocks base method. +func (m *MockProvider) SearchSkills(query string) (*api.SkillsListResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SearchSkills", query) + ret0, _ := ret[0].(*api.SkillsListResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SearchSkills indicates an expected call of SearchSkills. +func (mr *MockProviderMockRecorder) SearchSkills(query any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchSkills", reflect.TypeOf((*MockProvider)(nil).SearchSkills), query) +} diff --git a/pkg/registry/provider.go b/pkg/registry/provider.go index b4abfd4260..3038507d63 100644 --- a/pkg/registry/provider.go +++ b/pkg/registry/provider.go @@ -3,7 +3,10 @@ package registry -import types "github.com/stacklok/toolhive-core/registry/types" +import ( + types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" +) //go:generate mockgen -destination=mocks/mock_provider.go -package=mocks -source=provider.go Provider @@ -30,4 +33,14 @@ type Provider interface { // ListImageServers returns all available container servers ListImageServers() ([]*types.ImageMetadata, error) + + // Skills methods + // Providers that don't support skills (Local, Remote) return nil/empty results via BaseProvider. + + // GetSkill returns a specific skill by namespace and name + GetSkill(namespace, name string) (*types.Skill, error) + // ListSkills returns all available skills + ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) + // SearchSkills searches for skills matching the query + SearchSkills(query string) (*api.SkillsListResult, error) } diff --git a/pkg/registry/provider_api.go b/pkg/registry/provider_api.go index f195119b52..b65fd95413 100644 --- a/pkg/registry/provider_api.go +++ b/pkg/registry/provider_api.go @@ -23,6 +23,7 @@ type APIRegistryProvider struct { apiURL string allowPrivateIp bool client api.Client + skillsClient api.SkillsClient } // NewAPIRegistryProvider creates a new API registry provider. @@ -34,10 +35,17 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth return nil, fmt.Errorf("failed to create API client: %w", err) } + // Create skills client + skillsClient, err := api.NewSkillsClient(apiURL, allowPrivateIp, tokenSource) + if err != nil { + return nil, fmt.Errorf("failed to create skills client: %w", err) + } + p := &APIRegistryProvider{ apiURL: apiURL, allowPrivateIp: allowPrivateIp, client: client, + skillsClient: skillsClient, } // Initialize the base provider with the GetRegistry function @@ -188,6 +196,27 @@ func (p *APIRegistryProvider) GetImageServer(name string) (*types.ImageMetadata, return nil, fmt.Errorf("server %s is not a container server", name) } +// GetSkill returns a specific skill by namespace and name. +func (p *APIRegistryProvider) GetSkill(namespace, name string) (*types.Skill, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.skillsClient.GetSkill(ctx, namespace, name) +} + +// ListSkills returns all available skills. +func (p *APIRegistryProvider) ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + return p.skillsClient.ListSkills(ctx, opts) +} + +// SearchSkills searches for skills matching the query. +func (p *APIRegistryProvider) SearchSkills(query string) (*api.SkillsListResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.skillsClient.SearchSkills(ctx, query) +} + // ConvertServerJSON converts an MCP Registry API ServerJSON to ToolHive ServerMetadata // Uses converters from converters.go (same package) // Note: Only handles OCI packages and remote servers, skips npm/pypi by design diff --git a/pkg/registry/provider_base.go b/pkg/registry/provider_base.go index e777a8bc8c..da59f631ef 100644 --- a/pkg/registry/provider_base.go +++ b/pkg/registry/provider_base.go @@ -8,6 +8,7 @@ import ( "strings" types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" ) // BaseProvider provides common implementation for registry providers @@ -131,6 +132,21 @@ func (p *BaseProvider) ListImageServers() ([]*types.ImageMetadata, error) { return results, nil } +// GetSkill returns nil for providers that don't support skills. +func (*BaseProvider) GetSkill(_, _ string) (*types.Skill, error) { + return nil, nil +} + +// ListSkills returns empty results for providers that don't support skills. +func (*BaseProvider) ListSkills(_ *api.SkillsListOptions) (*api.SkillsListResult, error) { + return &api.SkillsListResult{}, nil +} + +// SearchSkills returns empty results for providers that don't support skills. +func (*BaseProvider) SearchSkills(_ string) (*api.SkillsListResult, error) { + return &api.SkillsListResult{}, nil +} + // matchesQuery checks if a server matches the search query func matchesQuery(name, description string, tags []string, query string) bool { // Search in name diff --git a/pkg/registry/provider_cached.go b/pkg/registry/provider_cached.go index 4a6bef6b67..7752170ebd 100644 --- a/pkg/registry/provider_cached.go +++ b/pkg/registry/provider_cached.go @@ -30,6 +30,8 @@ const ( // CachedAPIRegistryProvider wraps APIRegistryProvider with caching support. // Provides both in-memory and optional persistent file caching. // Works for both CLI (with persistent cache) and API server (memory only). +// Skills methods (GetSkill, ListSkills, SearchSkills) are uncached pass-through +// to the underlying APIRegistryProvider; skills caching may be added later. type CachedAPIRegistryProvider struct { *APIRegistryProvider diff --git a/pkg/registry/skills_provider_test.go b/pkg/registry/skills_provider_test.go new file mode 100644 index 0000000000..39b31f575d --- /dev/null +++ b/pkg/registry/skills_provider_test.go @@ -0,0 +1,339 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" +) + +func TestBaseProvider_SkillMethods(t *testing.T) { + t.Parallel() + + bp := NewBaseProvider(func() (*thvregistry.Registry, error) { + return &thvregistry.Registry{}, nil + }) + + t.Run("GetSkill returns nil", func(t *testing.T) { + t.Parallel() + skill, err := bp.GetSkill("any-namespace", "any-name") + require.NoError(t, err) + require.Nil(t, skill) + }) + + t.Run("ListSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := bp.ListSkills(nil) + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) + + t.Run("SearchSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := bp.SearchSkills("any-query") + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) +} + +func TestLocalRegistryProvider_SkillMethods(t *testing.T) { + t.Parallel() + + provider := NewLocalRegistryProvider() + + t.Run("GetSkill returns nil", func(t *testing.T) { + t.Parallel() + skill, err := provider.GetSkill("any-namespace", "any-name") + require.NoError(t, err) + require.Nil(t, skill) + }) + + t.Run("ListSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := provider.ListSkills(&api.SkillsListOptions{Search: "test"}) + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) + + t.Run("SearchSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := provider.SearchSkills("any-query") + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) +} + +func TestAPIRegistryProvider_GetSkill(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "returns skill from API", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills/") { + // Handle the validation probe for ListServers + writeEmptyServerList(w) + return + } + assert.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }) + assert.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }, + }, + { + name: "returns error on not found", + namespace: "io.github.user", + skillName: "nonexistent", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills/") { + writeEmptyServerList(w) + return + } + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("skill not found")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + skill, err := provider.GetSkill(tt.namespace, tt.skillName) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestAPIRegistryProvider_ListSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts *api.SkillsListOptions + handler http.HandlerFunc + wantSkills []*thvregistry.Skill + wantEmpty bool + wantErr bool + }{ + { + name: "returns skills list", + opts: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "2.0.0"}, + }, + Metadata: skillsListMetadata{Count: 2, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "2.0.0"}, + }, + }, + { + name: "returns empty list", + opts: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: skillsListMetadata{Count: 0, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + result, err := provider.ListSkills(tt.opts) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + if tt.wantEmpty { + require.Empty(t, result.Skills) + } else { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +func TestAPIRegistryProvider_SearchSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query string + handler http.HandlerFunc + wantSkills []*thvregistry.Skill + wantEmpty bool + wantErr bool + }{ + { + name: "returns matching skills", + query: "kubernetes", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + assert.Equal(t, "kubernetes", r.URL.Query().Get("search")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-deploy", Version: "1.0.0", Description: "Kubernetes deploy skill"}, + }, + Metadata: skillsListMetadata{Count: 1, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-deploy", Version: "1.0.0", Description: "Kubernetes deploy skill"}, + }, + }, + { + name: "returns empty for no matches", + query: "nonexistent-query", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: skillsListMetadata{Count: 0, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantEmpty: true, + }, + { + name: "returns error on server failure", + query: "test", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + result, err := provider.SearchSkills(tt.query) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + if tt.wantEmpty { + require.Empty(t, result.Skills) + } else { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +// skillsListWireResponse mirrors the JSON wire format for skills list/search responses. +// This is used in test handlers to produce realistic API responses. +type skillsListWireResponse struct { + Skills []*thvregistry.Skill `json:"skills"` + Metadata skillsListMetadata `json:"metadata"` +} + +type skillsListMetadata struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` +} + +// writeEmptyServerList writes a minimal valid ServerListResponse for the +// validation probe that NewAPIRegistryProvider performs (GET /v0.1/servers?limit=1&version=latest). +func writeEmptyServerList(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"servers":[],"metadata":{"count":0,"nextCursor":""}}`)) +}