Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 51 additions & 5 deletions pkg/config/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
Expand Down Expand Up @@ -151,7 +152,7 @@ func isValidAPIResponse(resp *http.Response) bool {
}

// isValidRegistryJSON checks if a URL returns valid ToolHive registry JSON
// by attempting to parse it into the actual Registry type
// by attempting to parse it. Accepts both upstream and legacy formats.
func isValidRegistryJSON(client *http.Client, url string) error {
resp, err := client.Get(url)
if err != nil {
Expand All @@ -163,9 +164,26 @@ func isValidRegistryJSON(client *http.Client, url string) error {
}
}()

// Parse into the actual Registry type for strong validation
data, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("%w: failed to read response body: %v", ErrRegistryValidationFailed, err)
}

// Try upstream format first
if isUpstreamRegistryFormat(data) {
var upstream registrytypes.UpstreamRegistry
if err := json.Unmarshal(data, &upstream); err != nil {
return fmt.Errorf("%w: invalid upstream JSON format: %v", ErrRegistryValidationFailed, err)
}
if len(upstream.Data.Servers) > 0 || len(upstream.Data.Groups) > 0 {
return nil
}
return fmt.Errorf("%w: upstream registry contains no servers", ErrRegistryValidationFailed)
}

// Fall back to legacy format
registry := &registrytypes.Registry{}
if err := json.NewDecoder(resp.Body).Decode(registry); err != nil {
if err := json.Unmarshal(data, registry); err != nil {
return fmt.Errorf("%w: invalid JSON format: %v", ErrRegistryValidationFailed, err)
}

Expand Down Expand Up @@ -314,6 +332,22 @@ func setRegistryFile(provider Provider, registryPath string) error {
return nil
}

// isUpstreamRegistryFormat returns true if the JSON data appears to be in the
// upstream MCP registry format. The key discriminator is the "data" wrapper
// object — only the upstream format wraps servers inside it.
// NOTE: keep in sync with isUpstreamFormat in pkg/registry/upstream_parser.go
// (duplicated to avoid a circular import).
func isUpstreamRegistryFormat(data []byte) bool {
var probe struct {
Data json.RawMessage `json:"data"`
}
if err := json.Unmarshal(data, &probe); err != nil {
return false
}
// The "data" wrapper object is unique to the upstream format.
return len(probe.Data) > 0 && probe.Data[0] == '{'
}

// registryHasServers checks if a registry contains at least one server
// (either in top-level servers/remote_servers or within groups)
func registryHasServers(registry *registrytypes.Registry) bool {
Expand All @@ -333,7 +367,7 @@ func registryHasServers(registry *registrytypes.Registry) bool {
}

// validateRegistryFileStructure checks if a file contains valid ToolHive registry structure
// by parsing it into the actual Registry type
// by parsing it into the actual Registry type. Accepts both upstream and legacy formats.
func validateRegistryFileStructure(path string) error {
// Read file content
// #nosec G304: File path is user-provided but validated by caller
Expand All @@ -342,7 +376,19 @@ func validateRegistryFileStructure(path string) error {
return fmt.Errorf("failed to read file: %w", err)
}

// Parse into the actual Registry type for strong validation
// Try upstream format first
if isUpstreamRegistryFormat(data) {
var upstream registrytypes.UpstreamRegistry
if err := json.Unmarshal(data, &upstream); err != nil {
return fmt.Errorf("invalid upstream registry format: %w", err)
}
if len(upstream.Data.Servers) > 0 || len(upstream.Data.Groups) > 0 {
return nil
}
return fmt.Errorf("upstream registry contains no servers or groups")
}

// Fall back to legacy format
registry := &registrytypes.Registry{}
if err := json.Unmarshal(data, registry); err != nil {
return fmt.Errorf("invalid registry format: %w", err)
Expand Down
123 changes: 123 additions & 0 deletions pkg/config/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

const testAPIEndpoint = "/v0.1/servers"
Expand Down Expand Up @@ -345,6 +347,127 @@ func TestIsValidRegistryJSON(t *testing.T) {
}
}

func TestValidateRegistryFileStructure_UpstreamFormat(t *testing.T) {
t.Parallel()

tests := []struct {
name string
content string
expectError bool
}{
{
name: "valid upstream format with servers",
content: `{
"$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json",
"version": "1.0.0",
"meta": {"last_updated": "2025-01-01T00:00:00Z"},
"data": {
"servers": [
{
"name": "io.example.test",
"description": "Test",
"packages": [{"registryType": "oci", "identifier": "test:latest", "transport": {"type": "stdio"}}]
}
]
}
}`,
expectError: false,
},
{
name: "upstream format with empty servers",
content: `{
"$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json",
"version": "1.0.0",
"meta": {"last_updated": "2025-01-01T00:00:00Z"},
"data": {"servers": []}
}`,
expectError: true,
},
{
name: "legacy format still works",
content: `{
"version": "1.0.0",
"servers": {"test": {"image": "test:latest"}}
}`,
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
path := tmpDir + "/registry.json"
require.NoError(t, os.WriteFile(path, []byte(tt.content), 0644))

err := validateRegistryFileStructure(path)
if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

func TestIsValidRegistryJSON_UpstreamFormat(t *testing.T) {
t.Parallel()

tests := []struct {
name string
body string
expectedError bool
}{
{
name: "valid upstream format",
body: `{
"$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json",
"version": "1.0.0",
"meta": {"last_updated": "2025-01-01T00:00:00Z"},
"data": {
"servers": [
{
"name": "io.example.test",
"description": "Test",
"packages": [{"registryType": "oci", "identifier": "test:latest", "transport": {"type": "stdio"}}]
}
]
}
}`,
expectedError: false,
},
{
name: "upstream format with no servers",
body: `{
"$schema": "https://cdn.mcpregistry.io/schema/v0/registry.json",
"version": "1.0.0",
"meta": {"last_updated": "2025-01-01T00:00:00Z"},
"data": {"servers": []}
}`,
expectedError: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(tt.body))
}))
defer server.Close()

client := &http.Client{}
err := isValidRegistryJSON(client, server.URL)
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

func TestProbeRegistryURL(t *testing.T) { //nolint:tparallel,paralleltest // Cannot use t.Parallel() on subtests using t.Setenv()
tests := []struct {
name string
Expand Down
72 changes: 36 additions & 36 deletions pkg/registry/mocks/mock_provider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 6 additions & 7 deletions pkg/registry/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@ type Provider interface {
// ListServers returns all available servers (both container and remote)
ListServers() ([]types.ServerMetadata, error)

// Legacy methods for backward compatibility
// GetImageServer returns a specific container server by name
GetImageServer(name string) (*types.ImageMetadata, error)
// ListAvailableSkills returns skills discovered from the registry data
ListAvailableSkills() ([]types.Skill, error)

// SearchImageServers searches for container servers matching the query
SearchImageServers(query string) ([]*types.ImageMetadata, error)
// GetSkill returns a specific skill by namespace and name
GetSkill(namespace, name string) (*types.Skill, error)

// ListImageServers returns all available container servers
ListImageServers() ([]*types.ImageMetadata, error)
// SearchSkills searches for skills matching the query
SearchSkills(query string) ([]types.Skill, error)
}
Loading
Loading