diff --git a/pkg/config/registry.go b/pkg/config/registry.go index 04d9847ca7..6f9e76e89c 100644 --- a/pkg/config/registry.go +++ b/pkg/config/registry.go @@ -8,6 +8,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "log/slog" "net" "net/http" @@ -151,7 +152,7 @@ func isValidAPIResponse(resp *http.Response) bool { } // isValidRegistryJSON checks if a URL returns valid ToolHive registry JSON -// by attempting to parse it into the actual Registry type +// by attempting to parse it. Accepts both upstream and legacy formats. func isValidRegistryJSON(client *http.Client, url string) error { resp, err := client.Get(url) if err != nil { @@ -163,9 +164,26 @@ func isValidRegistryJSON(client *http.Client, url string) error { } }() - // Parse into the actual Registry type for strong validation + data, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("%w: failed to read response body: %v", ErrRegistryValidationFailed, err) + } + + // Try upstream format first + if isUpstreamRegistryFormat(data) { + var upstream registrytypes.UpstreamRegistry + if err := json.Unmarshal(data, &upstream); err != nil { + return fmt.Errorf("%w: invalid upstream JSON format: %v", ErrRegistryValidationFailed, err) + } + if len(upstream.Data.Servers) > 0 || len(upstream.Data.Groups) > 0 { + return nil + } + return fmt.Errorf("%w: upstream registry contains no servers", ErrRegistryValidationFailed) + } + + // Fall back to legacy format registry := ®istrytypes.Registry{} - if err := json.NewDecoder(resp.Body).Decode(registry); err != nil { + if err := json.Unmarshal(data, registry); err != nil { return fmt.Errorf("%w: invalid JSON format: %v", ErrRegistryValidationFailed, err) } @@ -314,6 +332,22 @@ func setRegistryFile(provider Provider, registryPath string) error { return nil } +// isUpstreamRegistryFormat returns true if the JSON data appears to be in the +// upstream MCP registry format. The key discriminator is the "data" wrapper +// object — only the upstream format wraps servers inside it. +// NOTE: keep in sync with isUpstreamFormat in pkg/registry/upstream_parser.go +// (duplicated to avoid a circular import). +func isUpstreamRegistryFormat(data []byte) bool { + var probe struct { + Data json.RawMessage `json:"data"` + } + if err := json.Unmarshal(data, &probe); err != nil { + return false + } + // The "data" wrapper object is unique to the upstream format. + return len(probe.Data) > 0 && probe.Data[0] == '{' +} + // registryHasServers checks if a registry contains at least one server // (either in top-level servers/remote_servers or within groups) func registryHasServers(registry *registrytypes.Registry) bool { @@ -333,7 +367,7 @@ func registryHasServers(registry *registrytypes.Registry) bool { } // validateRegistryFileStructure checks if a file contains valid ToolHive registry structure -// by parsing it into the actual Registry type +// by parsing it into the actual Registry type. Accepts both upstream and legacy formats. func validateRegistryFileStructure(path string) error { // Read file content // #nosec G304: File path is user-provided but validated by caller @@ -342,7 +376,19 @@ func validateRegistryFileStructure(path string) error { return fmt.Errorf("failed to read file: %w", err) } - // Parse into the actual Registry type for strong validation + // Try upstream format first + if isUpstreamRegistryFormat(data) { + var upstream registrytypes.UpstreamRegistry + if err := json.Unmarshal(data, &upstream); err != nil { + return fmt.Errorf("invalid upstream registry format: %w", err) + } + if len(upstream.Data.Servers) > 0 || len(upstream.Data.Groups) > 0 { + return nil + } + return fmt.Errorf("upstream registry contains no servers or groups") + } + + // Fall back to legacy format registry := ®istrytypes.Registry{} if err := json.Unmarshal(data, registry); err != nil { return fmt.Errorf("invalid registry format: %w", err) diff --git a/pkg/config/registry_test.go b/pkg/config/registry_test.go index 9f0bb6e352..d12924c04a 100644 --- a/pkg/config/registry_test.go +++ b/pkg/config/registry_test.go @@ -7,9 +7,11 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const testAPIEndpoint = "/v0.1/servers" @@ -345,6 +347,127 @@ func TestIsValidRegistryJSON(t *testing.T) { } } +func TestValidateRegistryFileStructure_UpstreamFormat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + content string + expectError bool + }{ + { + name: "valid upstream format with servers", + content: `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": {"last_updated": "2025-01-01T00:00:00Z"}, + "data": { + "servers": [ + { + "name": "io.example.test", + "description": "Test", + "packages": [{"registryType": "oci", "identifier": "test:latest", "transport": {"type": "stdio"}}] + } + ] + } + }`, + expectError: false, + }, + { + name: "upstream format with empty servers", + content: `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": {"last_updated": "2025-01-01T00:00:00Z"}, + "data": {"servers": []} + }`, + expectError: true, + }, + { + name: "legacy format still works", + content: `{ + "version": "1.0.0", + "servers": {"test": {"image": "test:latest"}} + }`, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + path := tmpDir + "/registry.json" + require.NoError(t, os.WriteFile(path, []byte(tt.content), 0644)) + + err := validateRegistryFileStructure(path) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestIsValidRegistryJSON_UpstreamFormat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + expectedError bool + }{ + { + name: "valid upstream format", + body: `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": {"last_updated": "2025-01-01T00:00:00Z"}, + "data": { + "servers": [ + { + "name": "io.example.test", + "description": "Test", + "packages": [{"registryType": "oci", "identifier": "test:latest", "transport": {"type": "stdio"}}] + } + ] + } + }`, + expectedError: false, + }, + { + name: "upstream format with no servers", + body: `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": {"last_updated": "2025-01-01T00:00:00Z"}, + "data": {"servers": []} + }`, + expectedError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + client := &http.Client{} + err := isValidRegistryJSON(client, server.URL) + if tt.expectedError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + func TestProbeRegistryURL(t *testing.T) { //nolint:tparallel,paralleltest // Cannot use t.Parallel() on subtests using t.Setenv() tests := []struct { name string diff --git a/pkg/registry/mocks/mock_provider.go b/pkg/registry/mocks/mock_provider.go index 883c9f16bd..5fc97deadf 100644 --- a/pkg/registry/mocks/mock_provider.go +++ b/pkg/registry/mocks/mock_provider.go @@ -40,21 +40,6 @@ func (m *MockProvider) EXPECT() *MockProviderMockRecorder { return m.recorder } -// GetImageServer mocks base method. -func (m *MockProvider) GetImageServer(name string) (*registry.ImageMetadata, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetImageServer", name) - ret0, _ := ret[0].(*registry.ImageMetadata) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetImageServer indicates an expected call of GetImageServer. -func (mr *MockProviderMockRecorder) GetImageServer(name any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetImageServer", reflect.TypeOf((*MockProvider)(nil).GetImageServer), name) -} - // GetRegistry mocks base method. func (m *MockProvider) GetRegistry() (*registry.Registry, error) { m.ctrl.T.Helper() @@ -85,49 +70,49 @@ func (mr *MockProviderMockRecorder) GetServer(name any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServer", reflect.TypeOf((*MockProvider)(nil).GetServer), name) } -// ListImageServers mocks base method. -func (m *MockProvider) ListImageServers() ([]*registry.ImageMetadata, error) { +// GetSkill mocks base method. +func (m *MockProvider) GetSkill(namespace, name string) (*registry.Skill, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListImageServers") - ret0, _ := ret[0].([]*registry.ImageMetadata) + ret := m.ctrl.Call(m, "GetSkill", namespace, name) + ret0, _ := ret[0].(*registry.Skill) ret1, _ := ret[1].(error) return ret0, ret1 } -// ListImageServers indicates an expected call of ListImageServers. -func (mr *MockProviderMockRecorder) ListImageServers() *gomock.Call { +// GetSkill indicates an expected call of GetSkill. +func (mr *MockProviderMockRecorder) GetSkill(namespace, name any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListImageServers", reflect.TypeOf((*MockProvider)(nil).ListImageServers)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSkill", reflect.TypeOf((*MockProvider)(nil).GetSkill), namespace, name) } -// ListServers mocks base method. -func (m *MockProvider) ListServers() ([]registry.ServerMetadata, error) { +// ListAvailableSkills mocks base method. +func (m *MockProvider) ListAvailableSkills() ([]registry.Skill, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListServers") - ret0, _ := ret[0].([]registry.ServerMetadata) + ret := m.ctrl.Call(m, "ListAvailableSkills") + ret0, _ := ret[0].([]registry.Skill) ret1, _ := ret[1].(error) return ret0, ret1 } -// ListServers indicates an expected call of ListServers. -func (mr *MockProviderMockRecorder) ListServers() *gomock.Call { +// ListAvailableSkills indicates an expected call of ListAvailableSkills. +func (mr *MockProviderMockRecorder) ListAvailableSkills() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServers", reflect.TypeOf((*MockProvider)(nil).ListServers)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAvailableSkills", reflect.TypeOf((*MockProvider)(nil).ListAvailableSkills)) } -// SearchImageServers mocks base method. -func (m *MockProvider) SearchImageServers(query string) ([]*registry.ImageMetadata, error) { +// ListServers mocks base method. +func (m *MockProvider) ListServers() ([]registry.ServerMetadata, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SearchImageServers", query) - ret0, _ := ret[0].([]*registry.ImageMetadata) + ret := m.ctrl.Call(m, "ListServers") + ret0, _ := ret[0].([]registry.ServerMetadata) ret1, _ := ret[1].(error) return ret0, ret1 } -// SearchImageServers indicates an expected call of SearchImageServers. -func (mr *MockProviderMockRecorder) SearchImageServers(query any) *gomock.Call { +// ListServers indicates an expected call of ListServers. +func (mr *MockProviderMockRecorder) ListServers() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchImageServers", reflect.TypeOf((*MockProvider)(nil).SearchImageServers), query) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListServers", reflect.TypeOf((*MockProvider)(nil).ListServers)) } // SearchServers mocks base method. @@ -144,3 +129,18 @@ func (mr *MockProviderMockRecorder) SearchServers(query any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchServers", reflect.TypeOf((*MockProvider)(nil).SearchServers), query) } + +// SearchSkills mocks base method. +func (m *MockProvider) SearchSkills(query string) ([]registry.Skill, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SearchSkills", query) + ret0, _ := ret[0].([]registry.Skill) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SearchSkills indicates an expected call of SearchSkills. +func (mr *MockProviderMockRecorder) SearchSkills(query any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchSkills", reflect.TypeOf((*MockProvider)(nil).SearchSkills), query) +} diff --git a/pkg/registry/provider.go b/pkg/registry/provider.go index b4abfd4260..1cffb4ca61 100644 --- a/pkg/registry/provider.go +++ b/pkg/registry/provider.go @@ -21,13 +21,12 @@ type Provider interface { // ListServers returns all available servers (both container and remote) ListServers() ([]types.ServerMetadata, error) - // Legacy methods for backward compatibility - // GetImageServer returns a specific container server by name - GetImageServer(name string) (*types.ImageMetadata, error) + // ListAvailableSkills returns skills discovered from the registry data + ListAvailableSkills() ([]types.Skill, error) - // SearchImageServers searches for container servers matching the query - SearchImageServers(query string) ([]*types.ImageMetadata, error) + // GetSkill returns a specific skill by namespace and name + GetSkill(namespace, name string) (*types.Skill, error) - // ListImageServers returns all available container servers - ListImageServers() ([]*types.ImageMetadata, error) + // SearchSkills searches for skills matching the query + SearchSkills(query string) ([]types.Skill, error) } diff --git a/pkg/registry/provider_api.go b/pkg/registry/provider_api.go index f195119b52..8f8780ea3c 100644 --- a/pkg/registry/provider_api.go +++ b/pkg/registry/provider_api.go @@ -23,6 +23,8 @@ type APIRegistryProvider struct { apiURL string allowPrivateIp bool client api.Client + tokenSource auth.TokenSource + skillsClient api.SkillsClient } // NewAPIRegistryProvider creates a new API registry provider. @@ -34,10 +36,15 @@ func NewAPIRegistryProvider(apiURL string, allowPrivateIp bool, tokenSource auth return nil, fmt.Errorf("failed to create API client: %w", err) } + // Create skills client (best-effort — skills API may not be available) + skillsClient, _ := api.NewSkillsClient(apiURL, allowPrivateIp, tokenSource) + p := &APIRegistryProvider{ apiURL: apiURL, allowPrivateIp: allowPrivateIp, client: client, + tokenSource: tokenSource, + skillsClient: skillsClient, } // Initialize the base provider with the GetRegistry function @@ -171,21 +178,34 @@ func (p *APIRegistryProvider) ListServers() ([]types.ServerMetadata, error) { return ConvertServersToMetadata(servers) } -// GetImageServer returns a specific container server by name (overrides BaseProvider) -// This override is necessary because BaseProvider.GetImageServer calls p.GetServer, -// which would call BaseProvider.GetServer instead of APIRegistryProvider.GetServer -func (p *APIRegistryProvider) GetImageServer(name string) (*types.ImageMetadata, error) { - server, err := p.GetServer(name) +// GetSkill returns a specific skill by namespace and name from the API. +func (p *APIRegistryProvider) GetSkill(namespace, name string) (*types.Skill, error) { + if p.skillsClient == nil { + return nil, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return p.skillsClient.GetSkill(ctx, namespace, name) +} + +// SearchSkills searches for skills matching the query via the API. +func (p *APIRegistryProvider) SearchSkills(query string) ([]types.Skill, error) { + if p.skillsClient == nil { + return nil, nil + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + result, err := p.skillsClient.SearchSkills(ctx, query) if err != nil { return nil, err } - - // Type assert to ImageMetadata - if img, ok := server.(*types.ImageMetadata); ok { - return img, nil + skills := make([]types.Skill, 0, len(result.Skills)) + for _, s := range result.Skills { + if s != nil { + skills = append(skills, *s) + } } - - return nil, fmt.Errorf("server %s is not a container server", name) + return skills, nil } // ConvertServerJSON converts an MCP Registry API ServerJSON to ToolHive ServerMetadata diff --git a/pkg/registry/provider_base.go b/pkg/registry/provider_base.go index e777a8bc8c..49e76b7f75 100644 --- a/pkg/registry/provider_base.go +++ b/pkg/registry/provider_base.go @@ -24,20 +24,62 @@ func NewBaseProvider(getRegistry func() (*types.Registry, error)) *BaseProvider } } -// GetServer returns a specific server by name (container or remote) +// GetServer returns a specific server by name (container or remote). +// Supports both full reverse-DNS names (io.github.stacklok/osv) and +// short names (osv) for backward compatibility. func (p *BaseProvider) GetServer(name string) (types.ServerMetadata, error) { reg, err := p.GetRegistryFunc() if err != nil { return nil, err } - // Use the registry's helper method + // Try exact match first server, found := reg.GetServerByName(name) - if !found { - return nil, fmt.Errorf("server not found: %s", name) + if found { + return server, nil + } + + // Fall back to short-name matching: check if name matches the last + // path component of any server's full reverse-DNS name. + // e.g. "osv" matches "io.github.stacklok/osv" + if !strings.Contains(name, "/") { + matches := findServersByShortName(reg, name) + if len(matches) == 1 { + return matches[0].server, nil + } + if len(matches) > 1 { + names := make([]string, len(matches)) + for i, m := range matches { + names[i] = m.fullName + } + return nil, fmt.Errorf("multiple servers match '%s': %s — use the full name", + name, strings.Join(names, ", ")) + } } - return server, nil + return nil, fmt.Errorf("server not found: %s", name) +} + +type shortNameMatch struct { + fullName string + server types.ServerMetadata +} + +// findServersByShortName returns all servers whose name ends with "/". +func findServersByShortName(reg *types.Registry, shortName string) []shortNameMatch { + suffix := "/" + shortName + var matches []shortNameMatch + for fullName, server := range reg.Servers { + if strings.HasSuffix(fullName, suffix) { + matches = append(matches, shortNameMatch{fullName, server}) + } + } + for fullName, server := range reg.RemoteServers { + if strings.HasSuffix(fullName, suffix) { + matches = append(matches, shortNameMatch{fullName, server}) + } + } + return matches } // SearchServers searches for servers matching the query (both container and remote) @@ -78,57 +120,20 @@ func (p *BaseProvider) ListServers() ([]types.ServerMetadata, error) { return reg.GetAllServers(), nil } -// Legacy methods for backward compatibility - -// GetImageServer returns a specific container server by name (legacy method) -func (p *BaseProvider) GetImageServer(name string) (*types.ImageMetadata, error) { - server, err := p.GetServer(name) - if err != nil { - return nil, err - } - - // Type assert to ImageMetadata - if img, ok := server.(*types.ImageMetadata); ok { - return img, nil - } - - return nil, fmt.Errorf("server %s is not a container server", name) +// ListAvailableSkills returns an empty slice by default. +// Providers that support skills (local, remote) override this. +func (*BaseProvider) ListAvailableSkills() ([]types.Skill, error) { + return nil, nil } -// SearchImageServers searches for container servers matching the query (legacy method) -func (p *BaseProvider) SearchImageServers(query string) ([]*types.ImageMetadata, error) { - servers, err := p.SearchServers(query) - if err != nil { - return nil, err - } - - // Filter to only container servers - var results []*types.ImageMetadata - for _, server := range servers { - if img, ok := server.(*types.ImageMetadata); ok { - results = append(results, img) - } - } - - return results, nil +// GetSkill returns nil for providers that don't support skills. +func (*BaseProvider) GetSkill(_, _ string) (*types.Skill, error) { + return nil, nil } -// ListImageServers returns all container servers (legacy method) -func (p *BaseProvider) ListImageServers() ([]*types.ImageMetadata, error) { - servers, err := p.ListServers() - if err != nil { - return nil, err - } - - // Filter to only container servers - var results []*types.ImageMetadata - for _, server := range servers { - if img, ok := server.(*types.ImageMetadata); ok { - results = append(results, img) - } - } - - return results, nil +// SearchSkills returns nil for providers that don't support skills. +func (*BaseProvider) SearchSkills(_ string) ([]types.Skill, error) { + return nil, nil } // matchesQuery checks if a server matches the search query diff --git a/pkg/registry/provider_cached.go b/pkg/registry/provider_cached.go index 4a6bef6b67..85bdced249 100644 --- a/pkg/registry/provider_cached.go +++ b/pkg/registry/provider_cached.go @@ -15,6 +15,7 @@ import ( v0 "github.com/modelcontextprotocol/registry/pkg/api/v0" types "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/api" "github.com/stacklok/toolhive/pkg/registry/auth" ) @@ -38,6 +39,12 @@ type CachedAPIRegistryProvider struct { cachedData *types.Registry cacheTime time.Time + // Skills cache + skillsMu sync.RWMutex + cachedSkills []types.Skill + skillsCacheSet bool + skillsTime time.Time + // Cache configuration cacheTTL time.Duration usePersistent bool @@ -142,19 +149,17 @@ func (p *CachedAPIRegistryProvider) ForceRefresh() error { } // GetServer returns a specific server by name (overrides base to use cache). +// Ensures the cache is loaded, then delegates to BaseProvider.GetServer which +// handles both exact and short-name resolution. func (p *CachedAPIRegistryProvider) GetServer(name string) (types.ServerMetadata, error) { - // For individual server lookups, we could query the API directly for freshness, - // or use the cached registry. Let's use cached registry for consistency. - registry, err := p.GetRegistry() - if err != nil { + // Ensure cache is loaded + if _, err := p.GetRegistry(); err != nil { return nil, err } - // Try to find in cached registry first - if server, ok := registry.Servers[name]; ok { - return server, nil - } - if server, ok := registry.RemoteServers[name]; ok { + // Use BaseProvider.GetServer which includes short-name resolution + server, err := p.BaseProvider.GetServer(name) + if err == nil { return server, nil } @@ -365,34 +370,74 @@ func (p *CachedAPIRegistryProvider) cleanupOldCaches() { // Ensure CachedAPIRegistryProvider implements Provider interface var _ Provider = (*CachedAPIRegistryProvider)(nil) -// Override methods that query individual servers to ensure they use cache - -// GetImageServer returns a specific container server by name (uses cache). -func (p *CachedAPIRegistryProvider) GetImageServer(name string) (*types.ImageMetadata, error) { +// GetRemoteServer returns a specific remote server by name (uses cache). +func (p *CachedAPIRegistryProvider) GetRemoteServer(name string) (*types.RemoteServerMetadata, error) { server, err := p.GetServer(name) if err != nil { return nil, err } - if img, ok := server.(*types.ImageMetadata); ok { - return img, nil + if remote, ok := server.(*types.RemoteServerMetadata); ok { + return remote, nil } - return nil, fmt.Errorf("server %s is not a container server", name) + return nil, fmt.Errorf("server %s is not a remote server", name) } -// GetRemoteServer returns a specific remote server by name (uses cache). -func (p *CachedAPIRegistryProvider) GetRemoteServer(name string) (*types.RemoteServerMetadata, error) { - server, err := p.GetServer(name) +// ListAvailableSkills returns skills from the registry API, with caching. +// Creates a SkillsClient on demand and fetches all skills with auto-pagination. +func (p *CachedAPIRegistryProvider) ListAvailableSkills() ([]types.Skill, error) { + // Check cache + p.skillsMu.RLock() + if p.skillsCacheSet && time.Since(p.skillsTime) < p.cacheTTL { + skills := p.cachedSkills + p.skillsMu.RUnlock() + return skills, nil + } + p.skillsMu.RUnlock() + + // Fetch from API + skillsClient, err := api.NewSkillsClient(p.apiURL, p.allowPrivateIp, p.tokenSource) if err != nil { - return nil, err + // Return cached data if available + p.skillsMu.RLock() + defer p.skillsMu.RUnlock() + if p.skillsCacheSet { + return p.cachedSkills, nil + } + return nil, fmt.Errorf("failed to create skills client: %w", err) } - if remote, ok := server.(*types.RemoteServerMetadata); ok { - return remote, nil + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + // ListSkills auto-paginates internally, returning all skills in one call + result, err := skillsClient.ListSkills(ctx, nil) + if err != nil { + // Return cached data if available, otherwise nil (skills are optional) + p.skillsMu.RLock() + defer p.skillsMu.RUnlock() + if p.skillsCacheSet { + return p.cachedSkills, nil + } + return nil, nil } - return nil, fmt.Errorf("server %s is not a remote server", name) + allSkills := make([]types.Skill, 0, len(result.Skills)) + for _, s := range result.Skills { + if s != nil { + allSkills = append(allSkills, *s) + } + } + + // Update cache + p.skillsMu.Lock() + p.cachedSkills = allSkills + p.skillsCacheSet = true + p.skillsTime = time.Now() + p.skillsMu.Unlock() + + return allSkills, nil } // ConvertServerJSON wraps ConvertServerJSON for cached provider diff --git a/pkg/registry/provider_local.go b/pkg/registry/provider_local.go index 25e871102f..9693506f77 100644 --- a/pkg/registry/provider_local.go +++ b/pkg/registry/provider_local.go @@ -6,7 +6,10 @@ package registry import ( "encoding/json" "fmt" + "log/slog" "os" + "strings" + "sync" catalog "github.com/stacklok/toolhive-catalog/pkg/catalog/toolhive" types "github.com/stacklok/toolhive-core/registry/types" @@ -16,6 +19,8 @@ import ( type LocalRegistryProvider struct { *BaseProvider filePath string + skillsMu sync.RWMutex + skills []types.Skill } // NewLocalRegistryProvider creates a new local registry provider @@ -38,22 +43,36 @@ func NewLocalRegistryProvider(filePath ...string) *LocalRegistryProvider { // GetRegistry returns the registry data from file path or embedded data func (p *LocalRegistryProvider) GetRegistry() (*types.Registry, error) { - var data []byte - var err error + var registry *types.Registry if p.filePath != "" { - // Read from local file - data, err = os.ReadFile(p.filePath) + // Read from local file — auto-detect format + data, err := os.ReadFile(p.filePath) if err != nil { return nil, fmt.Errorf("failed to read local registry file %s: %w", p.filePath, err) } - } else { - data = catalog.Legacy() - } - registry, err := parseRegistryData(data) - if err != nil { - return nil, err + var skills []types.Skill + var isLegacy bool + registry, skills, isLegacy, err = parseRegistryAutoDetect(data) + if err != nil { + return nil, err + } + p.setSkills(skills) + if isLegacy { + slog.Warn("Registry file uses legacy format; please migrate to the upstream MCP format. "+ + "Legacy format support will be removed in a future release.", + "file", p.filePath) + } + } else { + // Embedded catalog — always upstream format + var err error + var skills []types.Skill + registry, skills, err = parseUpstreamRegistry(catalog.Upstream()) + if err != nil { + return nil, fmt.Errorf("failed to parse embedded upstream registry: %w", err) + } + p.setSkills(skills) } // Set name field on each server based on map key @@ -80,6 +99,64 @@ func (p *LocalRegistryProvider) GetRegistry() (*types.Registry, error) { return registry, nil } +func (p *LocalRegistryProvider) setSkills(skills []types.Skill) { + p.skillsMu.Lock() + defer p.skillsMu.Unlock() + p.skills = skills +} + +// ListAvailableSkills returns skills discovered from the upstream registry data. +// Triggers a registry load if skills haven't been populated yet. +func (p *LocalRegistryProvider) ListAvailableSkills() ([]types.Skill, error) { + p.skillsMu.RLock() + skills := p.skills + p.skillsMu.RUnlock() + + if skills == nil { + // Skills are populated as a side effect of GetRegistry + if _, err := p.GetRegistry(); err != nil { + return nil, err + } + p.skillsMu.RLock() + skills = p.skills + p.skillsMu.RUnlock() + } + + return skills, nil +} + +// GetSkill returns a specific skill by namespace and name. +func (p *LocalRegistryProvider) GetSkill(namespace, name string) (*types.Skill, error) { + skills, err := p.ListAvailableSkills() + if err != nil { + return nil, err + } + for i := range skills { + if skills[i].Namespace == namespace && skills[i].Name == name { + return &skills[i], nil + } + } + return nil, nil +} + +// SearchSkills searches for skills matching the query in name or description. +func (p *LocalRegistryProvider) SearchSkills(query string) ([]types.Skill, error) { + skills, err := p.ListAvailableSkills() + if err != nil { + return nil, err + } + query = strings.ToLower(query) + var results []types.Skill + for _, s := range skills { + if strings.Contains(strings.ToLower(s.Name), query) || + strings.Contains(strings.ToLower(s.Description), query) || + strings.Contains(strings.ToLower(s.Namespace), query) { + results = append(results, s) + } + } + return results, nil +} + // parseRegistryData parses JSON data into a Registry struct func parseRegistryData(data []byte) (*types.Registry, error) { registry := &types.Registry{} diff --git a/pkg/registry/provider_remote.go b/pkg/registry/provider_remote.go index 990cc8b62d..30688bdbf8 100644 --- a/pkg/registry/provider_remote.go +++ b/pkg/registry/provider_remote.go @@ -9,6 +9,8 @@ import ( "io" "log/slog" "net/http" + "strings" + "sync" "time" types "github.com/stacklok/toolhive-core/registry/types" @@ -20,6 +22,8 @@ type RemoteRegistryProvider struct { *BaseProvider registryURL string allowPrivateIp bool + skillsMu sync.RWMutex + skills []types.Skill } // NewRemoteRegistryProvider creates a new remote registry provider. @@ -76,13 +80,24 @@ func (p *RemoteRegistryProvider) validateConnectivity() error { return fmt.Errorf("failed to read registry response: %w", err) } + // Try upstream format first, fall back to legacy + if isUpstreamFormat(data) { + var upstream types.UpstreamRegistry + if err := json.Unmarshal(data, &upstream); err != nil { + return fmt.Errorf("registry returned invalid upstream JSON from %s: %w", p.registryURL, err) + } + if len(upstream.Data.Servers) == 0 && len(upstream.Data.Groups) == 0 { + return fmt.Errorf("registry at %s returned upstream format with no servers or groups", p.registryURL) + } + return nil + } + registry := &types.Registry{} if err := json.Unmarshal(data, registry); err != nil { return fmt.Errorf("registry returned invalid JSON from %s: %w", p.registryURL, err) } // Validate the registry has at least the required structure - // (we don't require servers/groups to exist, but the structure must be valid) if registry.Servers == nil && registry.RemoteServers == nil && registry.Groups == nil { return fmt.Errorf("registry at %s returned invalid structure: "+ "missing servers, remote_servers, and groups fields", p.registryURL) @@ -125,9 +140,15 @@ func (p *RemoteRegistryProvider) GetRegistry() (*types.Registry, error) { return nil, fmt.Errorf("failed to read registry data from response body: %w", err) } - registry := &types.Registry{} - if err := json.Unmarshal(data, registry); err != nil { - return nil, fmt.Errorf("failed to parse registry data: %w", err) + registry, skills, isLegacy, err := parseRegistryAutoDetect(data) + if err != nil { + return nil, fmt.Errorf("failed to parse registry data from %s: %w", p.registryURL, err) + } + p.setSkills(skills) + if isLegacy { + slog.Warn("Remote registry uses legacy format; please migrate to the upstream MCP format. "+ + "Legacy format support will be removed in a future release.", + "url", p.registryURL) } // Set name field on each server based on map key @@ -153,3 +174,61 @@ func (p *RemoteRegistryProvider) GetRegistry() (*types.Registry, error) { return registry, nil } + +// ListAvailableSkills returns skills discovered from the remote registry data. +// Triggers a registry load if skills haven't been populated yet. +func (p *RemoteRegistryProvider) ListAvailableSkills() ([]types.Skill, error) { + p.skillsMu.RLock() + skills := p.skills + p.skillsMu.RUnlock() + + if skills == nil { + // Skills are populated as a side effect of GetRegistry + if _, err := p.GetRegistry(); err != nil { + return nil, err + } + p.skillsMu.RLock() + skills = p.skills + p.skillsMu.RUnlock() + } + + return skills, nil +} + +// GetSkill returns a specific skill by namespace and name. +func (p *RemoteRegistryProvider) GetSkill(namespace, name string) (*types.Skill, error) { + skills, err := p.ListAvailableSkills() + if err != nil { + return nil, err + } + for i := range skills { + if skills[i].Namespace == namespace && skills[i].Name == name { + return &skills[i], nil + } + } + return nil, nil +} + +// SearchSkills searches for skills matching the query in name or description. +func (p *RemoteRegistryProvider) SearchSkills(query string) ([]types.Skill, error) { + skills, err := p.ListAvailableSkills() + if err != nil { + return nil, err + } + query = strings.ToLower(query) + var results []types.Skill + for _, s := range skills { + if strings.Contains(strings.ToLower(s.Name), query) || + strings.Contains(strings.ToLower(s.Description), query) || + strings.Contains(strings.ToLower(s.Namespace), query) { + results = append(results, s) + } + } + return results, nil +} + +func (p *RemoteRegistryProvider) setSkills(skills []types.Skill) { + p.skillsMu.Lock() + defer p.skillsMu.Unlock() + p.skills = skills +} diff --git a/pkg/registry/provider_test.go b/pkg/registry/provider_test.go index 77c74d2616..3911b879b7 100644 --- a/pkg/registry/provider_test.go +++ b/pkg/registry/provider_test.go @@ -343,6 +343,208 @@ func TestLocalRegistryProviderWithLocalFile(t *testing.T) { } } +func TestLocalRegistryProviderWithUpstreamFormatFile(t *testing.T) { + t.Parallel() + + // Create a temporary upstream-format registry file + tempDir := t.TempDir() + registryFile := filepath.Join(tempDir, "upstream_registry.json") + + testRegistry := `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": { + "last_updated": "2025-01-01T00:00:00Z" + }, + "data": { + "servers": [ + { + "name": "io.example.test-server", + "description": "Test server", + "packages": [ + { + "registryType": "oci", + "identifier": "example/test-server:latest", + "transport": { + "type": "stdio" + } + } + ] + } + ] + } + }` + + err := os.WriteFile(registryFile, []byte(testRegistry), 0644) + require.NoError(t, err) + + provider := NewLocalRegistryProvider(registryFile) + + registry, err := provider.GetRegistry() + require.NoError(t, err) + require.NotNil(t, registry) + + assert.NotEmpty(t, registry.Servers, "Should have at least one container server") +} + +func TestLocalRegistryProviderLegacyFormatFallback(t *testing.T) { + t.Parallel() + + // Create a legacy-format registry file + tempDir := t.TempDir() + registryFile := filepath.Join(tempDir, "legacy_registry.json") + + testRegistry := `{ + "version": "1.0.0", + "last_updated": "2023-01-01T00:00:00Z", + "servers": { + "test-server": { + "image": "test/image:latest", + "description": "Test server" + } + } + }` + + err := os.WriteFile(registryFile, []byte(testRegistry), 0644) + require.NoError(t, err) + + provider := NewLocalRegistryProvider(registryFile) + + registry, err := provider.GetRegistry() + require.NoError(t, err) + require.NotNil(t, registry) + + assert.Len(t, registry.Servers, 1) + server, exists := registry.Servers["test-server"] + assert.True(t, exists) + assert.Equal(t, "test/image:latest", server.Image) +} + +func TestRemoteRegistryProvider_UpstreamFormat(t *testing.T) { + t.Parallel() + + responseBody := `{ + "$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json", + "version": "1.0.0", + "meta": { + "last_updated": "2025-01-01T00:00:00Z" + }, + "data": { + "servers": [ + { + "name": "io.example.test-server", + "description": "Test server", + "packages": [ + { + "registryType": "oci", + "identifier": "example/test-server:latest", + "transport": { + "type": "stdio" + } + } + ] + } + ] + } + }` + + server := createTestServer(responseBody, 200) + defer server.Close() + + provider, err := NewRemoteRegistryProvider(server.URL, true) + require.NoError(t, err) + require.NotNil(t, provider) + + registry, err := provider.GetRegistry() + require.NoError(t, err) + assert.NotEmpty(t, registry.Servers, "Should have at least one container server") +} + +func TestGetServer_ShortNameResolution(t *testing.T) { + t.Parallel() + + // Build a controlled registry with known names + reg := &types.Registry{ + Version: "1.0.0", + LastUpdated: "2025-01-01T00:00:00Z", + Servers: map[string]*types.ImageMetadata{ + "io.github.stacklok/osv": {BaseServerMetadata: types.BaseServerMetadata{Name: "io.github.stacklok/osv"}, Image: "ghcr.io/osv:latest"}, + "io.github.stacklok/github": {BaseServerMetadata: types.BaseServerMetadata{Name: "io.github.stacklok/github"}, Image: "ghcr.io/github:latest"}, + "io.github.acme/github": {BaseServerMetadata: types.BaseServerMetadata{Name: "io.github.acme/github"}, Image: "ghcr.io/acme-github:latest"}, + }, + RemoteServers: map[string]*types.RemoteServerMetadata{ + "io.github.stacklok/slack-remote": {BaseServerMetadata: types.BaseServerMetadata{Name: "io.github.stacklok/slack-remote"}, URL: "https://slack.example.com"}, + }, + } + + provider := &LocalRegistryProvider{} + provider.BaseProvider = NewBaseProvider(func() (*types.Registry, error) { + return reg, nil + }) + + tests := []struct { + name string + query string + expectName string + expectError string + }{ + { + name: "exact full name match", + query: "io.github.stacklok/osv", + expectName: "io.github.stacklok/osv", + }, + { + name: "unique short name match", + query: "osv", + expectName: "io.github.stacklok/osv", + }, + { + name: "ambiguous short name errors with full names", + query: "github", + expectError: "multiple servers match 'github'", + }, + { + name: "ambiguous error lists both full names", + query: "github", + expectError: "io.github.stacklok/github", + }, + { + name: "ambiguous error lists both full names (second)", + query: "github", + expectError: "io.github.acme/github", + }, + { + name: "short name for remote server", + query: "slack-remote", + expectName: "io.github.stacklok/slack-remote", + }, + { + name: "no match returns not found", + query: "nonexistent", + expectError: "server not found: nonexistent", + }, + { + name: "partial name does not match (github-remote suffix check)", + query: "remote", + expectError: "server not found: remote", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server, err := provider.GetServer(tt.query) + if tt.expectError != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.expectError) + return + } + require.NoError(t, err) + assert.Equal(t, tt.expectName, server.GetName()) + }) + } +} + // getTypeName returns the type name of an interface value func getTypeName(v interface{}) string { switch v.(type) { @@ -421,7 +623,7 @@ func TestGetServer(t *testing.T) { provider, err := NewRegistryProvider(cfg) require.NoError(t, err) - // Test getting an existing server + // Test getting an existing server (short name resolves via suffix match) server, err := provider.GetServer("osv") if err != nil { t.Fatalf("Failed to get server: %v", err) diff --git a/pkg/registry/schema_validation_test.go b/pkg/registry/schema_validation_test.go index 97bb8a725f..119b1d4aef 100644 --- a/pkg/registry/schema_validation_test.go +++ b/pkg/registry/schema_validation_test.go @@ -4,7 +4,6 @@ package registry import ( - "encoding/json" "testing" "github.com/stretchr/testify/assert" @@ -14,27 +13,150 @@ import ( types "github.com/stacklok/toolhive-core/registry/types" ) -// TestEmbeddedRegistrySchemaValidation validates that the embedded registry.json -// conforms to the registry schema. This is the main test that ensures our -// registry data is always valid. +// TestEmbeddedRegistrySchemaValidation validates that the embedded upstream registry +// conforms to the upstream registry schema. func TestEmbeddedRegistrySchemaValidation(t *testing.T) { t.Parallel() - err := types.ValidateRegistrySchema(catalog.Legacy()) - require.NoError(t, err, "Embedded registry.json must conform to the registry schema") + err := types.ValidateUpstreamRegistryBytes(catalog.Upstream()) + require.NoError(t, err, "Embedded upstream registry must conform to the upstream registry schema") } -// TestValidateEmbeddedRegistryCanLoadData tests that we can actually load the embedded registry +// TestValidateEmbeddedRegistryCanLoadData tests that we can load the embedded upstream +// registry and convert it to the internal format. func TestValidateEmbeddedRegistryCanLoadData(t *testing.T) { t.Parallel() - // Verify it's valid JSON - var registry types.Registry - err := json.Unmarshal(catalog.Legacy(), ®istry) - require.NoError(t, err, "Embedded registry should be valid JSON") + registry, skills, err := parseUpstreamRegistry(catalog.Upstream()) + require.NoError(t, err, "Embedded upstream registry should parse successfully") // Verify basic structure assert.NotEmpty(t, registry.Version, "Registry should have a version") assert.NotEmpty(t, registry.LastUpdated, "Registry should have a last_updated timestamp") - assert.NotNil(t, registry.Servers, "Registry should have a servers map") + assert.True(t, len(registry.Servers) > 0 || len(registry.RemoteServers) > 0, + "Registry should have at least one server") + + // Skills may or may not be present in the catalog, just verify no error + _ = skills +} + +func TestIsUpstreamFormat(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected bool + }{ + { + name: "upstream format with data object", + input: `{"$schema": "https://example.com/schema.json", "data": {"servers": []}}`, + expected: true, + }, + { + name: "upstream format data only, no schema", + input: `{"data": {"servers": []}}`, + expected: true, + }, + { + name: "legacy format with schema but no data object", + input: `{"$schema": "https://example.com/legacy.json", "version": "1.0", "servers": {}}`, + expected: false, + }, + { + name: "legacy format no schema", + input: `{"version": "1.0", "servers": {"osv": {}}}`, + expected: false, + }, + { + name: "data is a string not object", + input: `{"data": "not an object"}`, + expected: false, + }, + { + name: "data is an array not object", + input: `{"data": [1, 2, 3]}`, + expected: false, + }, + { + name: "data is null", + input: `{"data": null}`, + expected: false, + }, + { + name: "empty JSON object", + input: `{}`, + expected: false, + }, + { + name: "invalid JSON", + input: `not json`, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.expected, isUpstreamFormat([]byte(tt.input))) + }) + } +} + +func TestParseRegistryAutoDetect(t *testing.T) { + t.Parallel() + + t.Run("upstream format returns isLegacy=false", func(t *testing.T) { + t.Parallel() + input := `{ + "$schema": "https://example.com/schema.json", + "version": "1.0.0", + "meta": {"last_updated": "2025-01-01T00:00:00Z"}, + "data": {"servers": []} + }` + _, _, isLegacy, err := parseRegistryAutoDetect([]byte(input)) + require.NoError(t, err) + assert.False(t, isLegacy) + }) + + t.Run("legacy format returns isLegacy=true", func(t *testing.T) { + t.Parallel() + input := `{ + "version": "1.0.0", + "servers": {"test": {"image": "test:latest"}} + }` + _, _, isLegacy, err := parseRegistryAutoDetect([]byte(input)) + require.NoError(t, err) + assert.True(t, isLegacy) + }) + + t.Run("legacy format with schema returns isLegacy=true", func(t *testing.T) { + t.Parallel() + input := `{ + "$schema": "https://example.com/legacy.json", + "version": "1.0.0", + "servers": {"test": {"image": "test:latest"}} + }` + _, _, isLegacy, err := parseRegistryAutoDetect([]byte(input)) + require.NoError(t, err) + assert.True(t, isLegacy) + }) +} + +// TestUpstreamRegistryParsing verifies that parseUpstreamRegistry correctly converts +// the embedded upstream catalog data. +func TestUpstreamRegistryParsing(t *testing.T) { + t.Parallel() + + registry, _, err := parseUpstreamRegistry(catalog.Upstream()) + require.NoError(t, err) + + // Verify servers have names set (from conversion) + for _, server := range registry.Servers { + assert.NotEmpty(t, server.Name, "Server should have a name") + assert.NotEmpty(t, server.Image, "Container server should have an image") + } + for _, server := range registry.RemoteServers { + assert.NotEmpty(t, server.Name, "Remote server should have a name") + } } diff --git a/pkg/registry/upstream_parser.go b/pkg/registry/upstream_parser.go new file mode 100644 index 0000000000..d8d41ede64 --- /dev/null +++ b/pkg/registry/upstream_parser.go @@ -0,0 +1,101 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package registry + +import ( + "encoding/json" + "fmt" + + v0 "github.com/modelcontextprotocol/registry/pkg/api/v0" + + types "github.com/stacklok/toolhive-core/registry/types" +) + +// parseUpstreamRegistry parses raw JSON in the upstream registry format and +// converts it into a legacy types.Registry plus any embedded skills. +func parseUpstreamRegistry(data []byte) (*types.Registry, []types.Skill, error) { + var upstream types.UpstreamRegistry + if err := json.Unmarshal(data, &upstream); err != nil { + return nil, nil, fmt.Errorf("failed to parse upstream registry data: %w", err) + } + + // ConvertServersToMetadata expects []*v0.ServerJSON, but UpstreamData.Servers + // is []v0.ServerJSON, so build a pointer slice. + serverPtrs := make([]*v0.ServerJSON, len(upstream.Data.Servers)) + for i := range upstream.Data.Servers { + serverPtrs[i] = &upstream.Data.Servers[i] + } + + serverMetadata, err := ConvertServersToMetadata(serverPtrs) + if err != nil { + return nil, nil, fmt.Errorf("failed to convert upstream servers to metadata: %w", err) + } + + // Build the legacy Registry, separating container and remote servers. + registry := &types.Registry{ + Version: upstream.Version, + LastUpdated: upstream.Meta.LastUpdated, + Servers: make(map[string]*types.ImageMetadata), + RemoteServers: make(map[string]*types.RemoteServerMetadata), + Groups: []*types.Group{}, + } + + for _, server := range serverMetadata { + if server.IsRemote() { + if remoteServer, ok := server.(*types.RemoteServerMetadata); ok { + registry.RemoteServers[remoteServer.Name] = remoteServer + } + } else { + if imageServer, ok := server.(*types.ImageMetadata); ok { + registry.Servers[imageServer.Name] = imageServer + } + } + } + + return registry, upstream.Data.Skills, nil +} + +// upstreamFormatProbe is a minimal struct used to detect whether JSON data is +// in the upstream registry format without fully unmarshalling it. +type upstreamFormatProbe struct { + Schema string `json:"$schema"` + Data json.RawMessage `json:"data"` +} + +// isUpstreamFormat returns true when the raw JSON appears to be in the upstream +// registry format. The key discriminator is the "data" wrapper object — only +// the upstream format wraps servers inside a "data" object. The "$schema" key +// alone is not sufficient because the legacy format also includes one. +// NOTE: keep in sync with isUpstreamRegistryFormat in pkg/config/registry.go +// (duplicated to avoid a circular import). +func isUpstreamFormat(data []byte) bool { + var probe upstreamFormatProbe + if err := json.Unmarshal(data, &probe); err != nil { + return false + } + // The "data" wrapper object is unique to the upstream format. + return len(probe.Data) > 0 && probe.Data[0] == '{' +} + +// parseRegistryAutoDetect attempts to parse the given JSON data by first +// checking whether it uses the upstream format. If so it delegates to +// parseUpstreamRegistry; otherwise it falls back to the legacy parser +// (parseRegistryData). The returned isLegacy flag indicates which path was +// taken. Skills are only returned for the upstream format. +func parseRegistryAutoDetect(data []byte) (*types.Registry, []types.Skill, bool, error) { + if isUpstreamFormat(data) { + reg, skills, err := parseUpstreamRegistry(data) + if err != nil { + return nil, nil, false, fmt.Errorf("upstream format detected but parsing failed: %w", err) + } + return reg, skills, false, nil + } + + // Legacy format — no skills. + reg, err := parseRegistryData(data) + if err != nil { + return nil, nil, true, fmt.Errorf("failed to parse legacy registry data: %w", err) + } + return reg, nil, true, nil +} diff --git a/pkg/runner/retriever/retriever.go b/pkg/runner/retriever/retriever.go index b47e6d9cdf..3622397d54 100644 --- a/pkg/runner/retriever/retriever.go +++ b/pkg/runner/retriever/retriever.go @@ -234,14 +234,13 @@ func handleRegistryLookup( return serverOrImage, nil, server, nil } // It's a container server, get the ImageMetadata - imageMetadata, err = provider.GetImageServer(serverOrImage) - if err != nil { - // This shouldn't happen since we just found it, but handle it anyway - slog.Debug("ImageMetadata not found in registry", "server", serverOrImage, "error", err) - imageToUse = serverOrImage + if imgMeta, ok := server.(*types.ImageMetadata); ok { + imageMetadata = imgMeta + imageToUse = imgMeta.Image + slog.Debug("Found imageMetadata in registry", "server", serverOrImage) } else { - slog.Debug("Found imageMetadata in registry", "server", serverOrImage, "metadata", imageMetadata) - imageToUse = imageMetadata.Image + slog.Debug("ImageMetadata not found in registry: could not cast", "server", serverOrImage) + imageToUse = serverOrImage } } else { // Server not found in registry, treat as a direct image reference diff --git a/test/e2e/fetch_mcp_server_test.go b/test/e2e/fetch_mcp_server_test.go index 54a163ad05..00a841313f 100644 --- a/test/e2e/fetch_mcp_server_test.go +++ b/test/e2e/fetch_mcp_server_test.go @@ -366,7 +366,7 @@ var _ = Describe("FetchMcpServer", Label("mcp", "mcp-run", "e2e"), func() { Expect(err).ToNot(HaveOccurred(), "Output should be valid JSON") // Verify required fields - Expect(serverInfo["name"]).To(Equal("fetch")) + Expect(serverInfo["name"]).To(ContainSubstring("fetch")) Expect(serverInfo["tools"]).ToNot(BeNil(), "Should have tools field") })