diff --git a/pkg/registry/mocks/mock_provider.go b/pkg/registry/mocks/mock_provider.go index 883c9f16bd..b206f6ae4d 100644 --- a/pkg/registry/mocks/mock_provider.go +++ b/pkg/registry/mocks/mock_provider.go @@ -13,6 +13,7 @@ import ( reflect "reflect" registry "github.com/stacklok/toolhive-core/registry/types" + api "github.com/stacklok/toolhive/pkg/registry/api" gomock "go.uber.org/mock/gomock" ) @@ -85,6 +86,21 @@ func (mr *MockProviderMockRecorder) GetServer(name any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServer", reflect.TypeOf((*MockProvider)(nil).GetServer), name) } +// GetSkill mocks base method. +func (m *MockProvider) GetSkill(namespace, name string) (*registry.Skill, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetSkill", namespace, name) + ret0, _ := ret[0].(*registry.Skill) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetSkill indicates an expected call of GetSkill. +func (mr *MockProviderMockRecorder) GetSkill(namespace, name any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSkill", reflect.TypeOf((*MockProvider)(nil).GetSkill), namespace, name) +} + // ListImageServers mocks base method. func (m *MockProvider) ListImageServers() ([]*registry.ImageMetadata, error) { m.ctrl.T.Helper() @@ -115,6 +131,21 @@ func (mr *MockProviderMockRecorder) ListServers() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServers", reflect.TypeOf((*MockProvider)(nil).ListServers)) } +// ListSkills mocks base method. +func (m *MockProvider) ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListSkills", opts) + ret0, _ := ret[0].(*api.SkillsListResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListSkills indicates an expected call of ListSkills. +func (mr *MockProviderMockRecorder) ListSkills(opts any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListSkills", reflect.TypeOf((*MockProvider)(nil).ListSkills), opts) +} + // SearchImageServers mocks base method. func (m *MockProvider) SearchImageServers(query string) ([]*registry.ImageMetadata, error) { m.ctrl.T.Helper() @@ -144,3 +175,18 @@ func (mr *MockProviderMockRecorder) SearchServers(query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchServers", reflect.TypeOf((*MockProvider)(nil).SearchServers), query) } + +// SearchSkills mocks base method. +func (m *MockProvider) SearchSkills(query string) (*api.SkillsListResult, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SearchSkills", query) + ret0, _ := ret[0].(*api.SkillsListResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SearchSkills indicates an expected call of SearchSkills. +func (mr *MockProviderMockRecorder) SearchSkills(query any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchSkills", reflect.TypeOf((*MockProvider)(nil).SearchSkills), query) +} diff --git a/pkg/registry/provider.go b/pkg/registry/provider.go index b4abfd4260..3038507d63 100644 --- a/pkg/registry/provider.go +++ b/pkg/registry/provider.go @@ -3,7 +3,10 @@ package registry -import types "github.com/stacklok/toolhive-core/registry/types" +import ( + types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" +) //go:generate mockgen -destination=mocks/mock_provider.go -package=mocks -source=provider.go Provider @@ -30,4 +33,14 @@ type Provider interface { // ListImageServers returns all available container servers ListImageServers() ([]*types.ImageMetadata, error) + + // Skills methods + // Providers that don't support skills (Local, Remote) return nil/empty results via BaseProvider. + + // GetSkill returns a specific skill by namespace and name + GetSkill(namespace, name string) (*types.Skill, error) + // ListSkills returns all available skills + ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) + // SearchSkills searches for skills matching the query + SearchSkills(query string) (*api.SkillsListResult, error) } diff --git a/pkg/registry/provider_api.go b/pkg/registry/provider_api.go index f195119b52..b65fd95413 100644 --- a/pkg/registry/provider_api.go +++ b/pkg/registry/provider_api.go @@ -23,6 +23,7 @@ type APIRegistryProvider struct { apiURL string allowPrivateIp bool client api.Client + skillsClient api.SkillsClient } // NewAPIRegistryProvider creates a new API registry provider. @@ -34,10 +35,17 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth return nil, fmt.Errorf("failed to create API client: %w", err) } + // Create skills client + skillsClient, err := api.NewSkillsClient(apiURL, allowPrivateIp, tokenSource) + if err != nil { + return nil, fmt.Errorf("failed to create skills client: %w", err) + } + p := &APIRegistryProvider{ apiURL: apiURL, allowPrivateIp: allowPrivateIp, client: client, + skillsClient: skillsClient, } // Initialize the base provider with the GetRegistry function @@ -188,6 +196,27 @@ func (p *APIRegistryProvider) GetImageServer(name string) (*types.ImageMetadata, return nil, fmt.Errorf("server %s is not a container server", name) } +// GetSkill returns a specific skill by namespace and name. +func (p *APIRegistryProvider) GetSkill(namespace, name string) (*types.Skill, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.skillsClient.GetSkill(ctx, namespace, name) +} + +// ListSkills returns all available skills. +func (p *APIRegistryProvider) ListSkills(opts *api.SkillsListOptions) (*api.SkillsListResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + return p.skillsClient.ListSkills(ctx, opts) +} + +// SearchSkills searches for skills matching the query. +func (p *APIRegistryProvider) SearchSkills(query string) (*api.SkillsListResult, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.skillsClient.SearchSkills(ctx, query) +} + // ConvertServerJSON converts an MCP Registry API ServerJSON to ToolHive ServerMetadata // Uses converters from converters.go (same package) // Note: Only handles OCI packages and remote servers, skips npm/pypi by design diff --git a/pkg/registry/provider_base.go b/pkg/registry/provider_base.go index e777a8bc8c..da59f631ef 100644 --- a/pkg/registry/provider_base.go +++ b/pkg/registry/provider_base.go @@ -8,6 +8,7 @@ import ( "strings" types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" ) // BaseProvider provides common implementation for registry providers @@ -131,6 +132,21 @@ func (p *BaseProvider) ListImageServers() ([]*types.ImageMetadata, error) { return results, nil } +// GetSkill returns nil for providers that don't support skills. +func (*BaseProvider) GetSkill(_, _ string) (*types.Skill, error) { + return nil, nil +} + +// ListSkills returns empty results for providers that don't support skills. +func (*BaseProvider) ListSkills(_ *api.SkillsListOptions) (*api.SkillsListResult, error) { + return &api.SkillsListResult{}, nil +} + +// SearchSkills returns empty results for providers that don't support skills. +func (*BaseProvider) SearchSkills(_ string) (*api.SkillsListResult, error) { + return &api.SkillsListResult{}, nil +} + // matchesQuery checks if a server matches the search query func matchesQuery(name, description string, tags []string, query string) bool { // Search in name diff --git a/pkg/registry/provider_cached.go b/pkg/registry/provider_cached.go index 4a6bef6b67..7752170ebd 100644 --- a/pkg/registry/provider_cached.go +++ b/pkg/registry/provider_cached.go @@ -30,6 +30,8 @@ const ( // CachedAPIRegistryProvider wraps APIRegistryProvider with caching support. // Provides both in-memory and optional persistent file caching. // Works for both CLI (with persistent cache) and API server (memory only). +// Skills methods (GetSkill, ListSkills, SearchSkills) are uncached pass-through +// to the underlying APIRegistryProvider; skills caching may be added later. type CachedAPIRegistryProvider struct { *APIRegistryProvider diff --git a/pkg/registry/skills_provider_test.go b/pkg/registry/skills_provider_test.go new file mode 100644 index 0000000000..39b31f575d --- /dev/null +++ b/pkg/registry/skills_provider_test.go @@ -0,0 +1,339 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" +) + +func TestBaseProvider_SkillMethods(t *testing.T) { + t.Parallel() + + bp := NewBaseProvider(func() (*thvregistry.Registry, error) { + return &thvregistry.Registry{}, nil + }) + + t.Run("GetSkill returns nil", func(t *testing.T) { + t.Parallel() + skill, err := bp.GetSkill("any-namespace", "any-name") + require.NoError(t, err) + require.Nil(t, skill) + }) + + t.Run("ListSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := bp.ListSkills(nil) + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) + + t.Run("SearchSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := bp.SearchSkills("any-query") + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) +} + +func TestLocalRegistryProvider_SkillMethods(t *testing.T) { + t.Parallel() + + provider := NewLocalRegistryProvider() + + t.Run("GetSkill returns nil", func(t *testing.T) { + t.Parallel() + skill, err := provider.GetSkill("any-namespace", "any-name") + require.NoError(t, err) + require.Nil(t, skill) + }) + + t.Run("ListSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := provider.ListSkills(&api.SkillsListOptions{Search: "test"}) + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) + + t.Run("SearchSkills returns empty result", func(t *testing.T) { + t.Parallel() + result, err := provider.SearchSkills("any-query") + require.NoError(t, err) + require.NotNil(t, result) + require.Empty(t, result.Skills) + }) +} + +func TestAPIRegistryProvider_GetSkill(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "returns skill from API", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills/") { + // Handle the validation probe for ListServers + writeEmptyServerList(w) + return + } + assert.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }) + assert.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }, + }, + { + name: "returns error on not found", + namespace: "io.github.user", + skillName: "nonexistent", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills/") { + writeEmptyServerList(w) + return + } + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("skill not found")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + skill, err := provider.GetSkill(tt.namespace, tt.skillName) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestAPIRegistryProvider_ListSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts *api.SkillsListOptions + handler http.HandlerFunc + wantSkills []*thvregistry.Skill + wantEmpty bool + wantErr bool + }{ + { + name: "returns skills list", + opts: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "2.0.0"}, + }, + Metadata: skillsListMetadata{Count: 2, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "2.0.0"}, + }, + }, + { + name: "returns empty list", + opts: nil, + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: skillsListMetadata{Count: 0, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + result, err := provider.ListSkills(tt.opts) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + if tt.wantEmpty { + require.Empty(t, result.Skills) + } else { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +func TestAPIRegistryProvider_SearchSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query string + handler http.HandlerFunc + wantSkills []*thvregistry.Skill + wantEmpty bool + wantErr bool + }{ + { + name: "returns matching skills", + query: "kubernetes", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + assert.Equal(t, "kubernetes", r.URL.Query().Get("search")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-deploy", Version: "1.0.0", Description: "Kubernetes deploy skill"}, + }, + Metadata: skillsListMetadata{Count: 1, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-deploy", Version: "1.0.0", Description: "Kubernetes deploy skill"}, + }, + }, + { + name: "returns empty for no matches", + query: "nonexistent-query", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListWireResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: skillsListMetadata{Count: 0, NextCursor: ""}, + }) + assert.NoError(t, err) + }, + wantEmpty: true, + }, + { + name: "returns error on server failure", + query: "test", + handler: func(w http.ResponseWriter, r *http.Request) { + if !strings.HasPrefix(r.URL.Path, "/v0.1/x/dev.toolhive/skills") { + writeEmptyServerList(w) + return + } + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal server error")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + provider, err := NewAPIRegistryProvider(server.URL, true, nil) + require.NoError(t, err) + + result, err := provider.SearchSkills(tt.query) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.NotNil(t, result) + if tt.wantEmpty { + require.Empty(t, result.Skills) + } else { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +// skillsListWireResponse mirrors the JSON wire format for skills list/search responses. +// This is used in test handlers to produce realistic API responses. +type skillsListWireResponse struct { + Skills []*thvregistry.Skill `json:"skills"` + Metadata skillsListMetadata `json:"metadata"` +} + +type skillsListMetadata struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` +} + +// writeEmptyServerList writes a minimal valid ServerListResponse for the +// validation probe that NewAPIRegistryProvider performs (GET /v0.1/servers?limit=1&version=latest). +func writeEmptyServerList(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"servers":[],"metadata":{"count":0,"nextCursor":""}}`)) +}