diff --git a/pkg/registry/api/client.go b/pkg/registry/api/client.go index 4caf9d63d4..910f2a5e11 100644 --- a/pkg/registry/api/client.go +++ b/pkg/registry/api/client.go @@ -16,7 +16,6 @@ import ( v0 "github.com/modelcontextprotocol/registry/pkg/api/v0" "gopkg.in/yaml.v3" - "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/registry/auth" "github.com/stacklok/toolhive/pkg/versions" ) @@ -63,18 +62,11 @@ type mcpRegistryClient struct { func NewClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (Client, error) { // Build HTTP client with security controls // If private IPs are allowed, also allow HTTP (for localhost testing) - builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp) - if allowPrivateIp { - builder = builder.WithInsecureAllowHTTP(true) - } - httpClient, err := builder.Build() + httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource) if err != nil { - return nil, fmt.Errorf("failed to build HTTP client: %w", err) + return nil, err } - // Wrap transport with auth if token source is provided - httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource) - // Ensure base URL doesn't have trailing slash if baseURL[len(baseURL)-1] == '/' { baseURL = baseURL[:len(baseURL)-1] @@ -112,8 +104,7 @@ func (c *mcpRegistryClient) GetServer(ctx context.Context, name string) (*v0.Ser }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, newRegistryHTTPError(resp) } var serverResp v0.ServerResponse @@ -207,8 +198,7 @@ func (c *mcpRegistryClient) fetchServersPage( }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, "", fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, "", newRegistryHTTPError(resp) } var listResp v0.ServerListResponse @@ -252,8 +242,7 @@ func (c *mcpRegistryClient) SearchServers(ctx context.Context, query string) ([] }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, newRegistryHTTPError(resp) } var listResp v0.ServerListResponse diff --git a/pkg/registry/api/shared.go b/pkg/registry/api/shared.go new file mode 100644 index 0000000000..c0ff6f482b --- /dev/null +++ b/pkg/registry/api/shared.go @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/registry/auth" +) + +const maxErrorBodySize = 4096 + +// ErrRegistryUnauthorized is a sentinel error for 401/403 responses from registry APIs. +var ErrRegistryUnauthorized = errors.New("registry authentication failed") + +// RegistryHTTPError represents an HTTP error from a registry API endpoint. +type RegistryHTTPError struct { + StatusCode int + Body string + URL string +} + +func (e *RegistryHTTPError) Error() string { + return fmt.Sprintf("registry API returned status %d for %s: %s", e.StatusCode, e.URL, e.Body) +} + +// Unwrap returns ErrRegistryUnauthorized for 401/403 status codes, +// allowing callers to use errors.Is(err, ErrRegistryUnauthorized). +func (e *RegistryHTTPError) Unwrap() error { + if e.StatusCode == http.StatusUnauthorized || e.StatusCode == http.StatusForbidden { + return ErrRegistryUnauthorized + } + return nil +} + +// buildHTTPClient creates an HTTP client with security controls and optional auth. +// If allowPrivateIp is true, HTTP (non-HTTPS) is also allowed for localhost testing. +func buildHTTPClient(allowPrivateIp bool, tokenSource auth.TokenSource) (*http.Client, error) { + builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp) + if allowPrivateIp { + builder = builder.WithInsecureAllowHTTP(true) + } + httpClient, err := builder.Build() + if err != nil { + return nil, fmt.Errorf("failed to build HTTP client: %w", err) + } + httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource) + return httpClient, nil +} + +// newRegistryHTTPError reads the response body (limited) and returns a RegistryHTTPError. +func newRegistryHTTPError(resp *http.Response) *RegistryHTTPError { + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) + return &RegistryHTTPError{ + StatusCode: resp.StatusCode, + Body: string(body), + URL: resp.Request.URL.String(), + } +} diff --git a/pkg/registry/api/skills_client.go b/pkg/registry/api/skills_client.go new file mode 100644 index 0000000000..642259dbae --- /dev/null +++ b/pkg/registry/api/skills_client.go @@ -0,0 +1,258 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/auth" + "github.com/stacklok/toolhive/pkg/versions" +) + +const skillsBasePath = "/v0.1/x/dev.toolhive/skills" + +// SkillsListOptions contains options for listing skills. +type SkillsListOptions struct { + // Search is an optional search query to filter skills. + Search string + // Limit is the maximum number of skills per page (default: 100). + Limit int + // Cursor is the pagination cursor for fetching the next page. + Cursor string +} + +// SkillsListResult contains a page of skills and pagination info. +type SkillsListResult struct { + Skills []*thvregistry.Skill + NextCursor string +} + +// SkillsClient provides access to the ToolHive Skills extension API. +type SkillsClient interface { + // GetSkill retrieves a skill by namespace and name (latest version). + GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) + // GetSkillVersion retrieves a specific version of a skill. + GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) + // ListSkills retrieves skills with optional filtering and pagination. + ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) + // SearchSkills searches for skills matching the query (single page, no auto-pagination). + SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) + // ListSkillVersions lists all versions of a specific skill. + ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) +} + +// NewSkillsClient creates a new ToolHive Skills extension API client. +// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject +// Bearer tokens into all requests. +func NewSkillsClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (SkillsClient, error) { + httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource) + if err != nil { + return nil, err + } + + // Ensure base URL doesn't have trailing slash + baseURL = strings.TrimRight(baseURL, "/") + + return &mcpSkillsClient{ + baseURL: baseURL, + httpClient: httpClient, + userAgent: versions.GetUserAgent(), + }, nil +} + +// GetSkill retrieves a skill by namespace and name (latest version). +func (c *mcpSkillsClient) GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name)) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var skill thvregistry.Skill + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { + return nil, err + } + return &skill, nil +} + +// GetSkillVersion retrieves a specific version of a skill. +func (c *mcpSkillsClient) GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, + url.PathEscape(namespace), url.PathEscape(name), + "versions", url.PathEscape(version)) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var skill thvregistry.Skill + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { + return nil, err + } + return &skill, nil +} + +// ListSkills retrieves skills with optional filtering and pagination. +// It auto-paginates through all available pages, concatenating results. +func (c *mcpSkillsClient) ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) { + if opts == nil { + opts = &SkillsListOptions{} + } + if opts.Limit == 0 { + opts.Limit = 100 + } + + var allSkills []*thvregistry.Skill + cursor := opts.Cursor + + // Pagination loop - continue until no more cursors + for { + page, nextCursor, err := c.fetchSkillsPage(ctx, cursor, opts) + if err != nil { + return nil, err + } + + allSkills = append(allSkills, page...) + + // Check if we have more pages + if nextCursor == "" { + break + } + + cursor = nextCursor + + // Safety limit: prevent infinite loops + if len(allSkills) > 10000 { + return nil, fmt.Errorf("exceeded maximum skills limit (10000)") + } + } + + return &SkillsListResult{ + Skills: allSkills, + }, nil +} + +// SearchSkills searches for skills matching the query. +// Returns a single page of results (no auto-pagination). +func (c *mcpSkillsClient) SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) { + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + params := url.Values{} + params.Add("search", query) + + endpoint := basePath + "?" + params.Encode() + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, err + } + + return &SkillsListResult{ + Skills: listResp.Skills, + NextCursor: listResp.Metadata.NextCursor, + }, nil +} + +// ListSkillVersions lists all versions of a specific skill. +func (c *mcpSkillsClient) ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name), "versions") + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, err + } + + return &SkillsListResult{ + Skills: listResp.Skills, + NextCursor: listResp.Metadata.NextCursor, + }, nil +} + +// mcpSkillsClient implements the SkillsClient interface. +type mcpSkillsClient struct { + baseURL string + httpClient *http.Client + userAgent string +} + +// skillsListResponse is the wire format for list/search responses. +type skillsListResponse struct { + Skills []*thvregistry.Skill `json:"skills"` + Metadata struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + } `json:"metadata"` +} + +// doSkillsGet performs an HTTP GET request and decodes the JSON response into dest. +func (c *mcpSkillsClient) doSkillsGet(ctx context.Context, endpoint string, dest any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("User-Agent", c.userAgent) + + resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL from configured registry + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Debug("failed to close response body", "error", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return newRegistryHTTPError(resp) + } + + if err := json.NewDecoder(resp.Body).Decode(dest); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + return nil +} + +// fetchSkillsPage fetches a single page of skills. +func (c *mcpSkillsClient) fetchSkillsPage( + ctx context.Context, cursor string, opts *SkillsListOptions, +) ([]*thvregistry.Skill, string, error) { + params := url.Values{} + if cursor != "" { + params.Add("cursor", cursor) + } + if opts.Limit > 0 { + params.Add("limit", fmt.Sprintf("%d", opts.Limit)) + } + if opts.Search != "" { + params.Add("search", opts.Search) + } + + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) + if err != nil { + return nil, "", fmt.Errorf("failed to build skills URL: %w", err) + } + endpoint := func() string { + if len(params) > 0 { + return basePath + "?" + params.Encode() + } + return basePath + }() + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, "", err + } + + return listResp.Skills, listResp.Metadata.NextCursor, nil +} diff --git a/pkg/registry/api/skills_client_test.go b/pkg/registry/api/skills_client_test.go new file mode 100644 index 0000000000..880f28a6d1 --- /dev/null +++ b/pkg/registry/api/skills_client_test.go @@ -0,0 +1,565 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" +) + +func newTestSkillsClient(t *testing.T, server *httptest.Server) SkillsClient { + t.Helper() + client, err := NewSkillsClient(server.URL, true, nil) + require.NoError(t, err) + return client +} + +func TestSkillsClient_GetSkill(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "success", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }, + }, + { + name: "not found", + namespace: "io.github.user", + skillName: "nonexistent", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("skill not found")) + }, + wantErr: true, + }, + { + name: "server error", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal error")) + }, + wantErr: true, + }, + { + name: "path escaping", + namespace: "io.github.user/special", + skillName: "my skill", + handler: func(w http.ResponseWriter, r *http.Request) { + // Verify that the path components are properly escaped + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user%2Fspecial/my%20skill", r.URL.RawPath) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user/special", + Name: "my skill", + Version: "1.0.0", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user/special", + Name: "my skill", + Version: "1.0.0", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + skill, err := client.GetSkill(t.Context(), tt.namespace, tt.skillName) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestSkillsClient_GetSkillVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + version string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "success", + namespace: "io.github.user", + skillName: "my-skill", + version: "2.0.0", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill/versions/2.0.0", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "2.0.0", + Description: "Version 2", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "2.0.0", + Description: "Version 2", + }, + }, + { + name: "version not found", + namespace: "io.github.user", + skillName: "my-skill", + version: "99.0.0", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("version not found")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + skill, err := client.GetSkillVersion(t.Context(), tt.namespace, tt.skillName, tt.version) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestSkillsClient_ListSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts *SkillsListOptions + handler http.HandlerFunc + wantCount int + wantErr bool + wantSkills []*thvregistry.Skill + }{ + { + name: "single page", + opts: nil, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 2, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 2, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + }, + { + name: "pagination across multiple pages", + opts: &SkillsListOptions{Limit: 1}, + handler: func() http.HandlerFunc { + callCount := 0 + return func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + + cursor := r.URL.Query().Get("cursor") + var resp skillsListResponse + + switch { + case cursor == "" && callCount == 1: + resp = skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: "page2"}, + } + case cursor == "page2": + resp = skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + } + default: + w.WriteHeader(http.StatusBadRequest) + return + } + + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + } + }(), + wantCount: 2, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + }, + { + name: "empty result", + opts: nil, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 0, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 0, + wantSkills: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkills(t.Context(), tt.opts) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, result.Skills, tt.wantCount) + if tt.wantSkills != nil { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +func TestSkillsClient_SearchSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query string + handler http.HandlerFunc + wantCount int + wantErr bool + }{ + { + name: "success with results", + query: "kubernetes", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "kubernetes", r.URL.Query().Get("search")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-skill", Version: "1.0.0", Description: "Kubernetes skill"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 1, + }, + { + name: "empty result", + query: "nonexistent", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 0, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.SearchSkills(t.Context(), tt.query) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, result.Skills, tt.wantCount) + }) + } +} + +func TestSkillsClient_ListSkillVersions(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill/versions", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "my-skill", Version: "1.0.0"}, + {Namespace: "io.github.user", Name: "my-skill", Version: "2.0.0"}, + {Namespace: "io.github.user", Name: "my-skill", Version: "3.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 3, NextCursor: ""}, + }) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkillVersions(t.Context(), "io.github.user", "my-skill") + require.NoError(t, err) + require.Len(t, result.Skills, 3) + require.Equal(t, "1.0.0", result.Skills[0].Version) + require.Equal(t, "2.0.0", result.Skills[1].Version) + require.Equal(t, "3.0.0", result.Skills[2].Version) +} + +func TestSkillsClient_ErrorHandling(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + wantErrIs error + }{ + { + name: "401 unauthorized", + statusCode: http.StatusUnauthorized, + body: "unauthorized", + wantErrIs: ErrRegistryUnauthorized, + }, + { + name: "403 forbidden", + statusCode: http.StatusForbidden, + body: "forbidden", + wantErrIs: ErrRegistryUnauthorized, + }, + { + name: "500 server error does not unwrap to unauthorized", + statusCode: http.StatusInternalServerError, + body: "internal server error", + wantErrIs: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + _, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.Error(t, err) + + var httpErr *RegistryHTTPError + require.True(t, errors.As(err, &httpErr), "expected *RegistryHTTPError, got %T", err) + require.Equal(t, tt.statusCode, httpErr.StatusCode) + require.Contains(t, httpErr.Body, tt.body) + + if tt.wantErrIs != nil { + require.True(t, errors.Is(err, tt.wantErrIs), + "expected errors.Is(%v, %v) to be true", err, tt.wantErrIs) + } else { + require.False(t, errors.Is(err, ErrRegistryUnauthorized), + "expected errors.Is(%v, ErrRegistryUnauthorized) to be false", err) + } + }) + } +} + +func TestSkillsClient_MalformedJSON(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + _, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to decode response") +} + +func TestSkillsClient_TrailingSlashInBaseURL(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The path should not have a double slash + require.NotContains(t, r.URL.Path, "//") + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + }) + require.NoError(t, err) + })) + defer server.Close() + + // Create client with trailing slash + client, err := NewSkillsClient(server.URL+"/", true, nil) + require.NoError(t, err) + + skill, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.NoError(t, err) + require.Equal(t, "io.github.user", skill.Namespace) +} + +func TestSkillsClient_ListSkillsWithSearch(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-query", r.URL.Query().Get("search")) + require.Equal(t, "50", r.URL.Query().Get("limit")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "test-skill", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + }) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkills(t.Context(), &SkillsListOptions{ + Search: "test-query", + Limit: 50, + }) + require.NoError(t, err) + require.Len(t, result.Skills, 1) + require.Equal(t, "test-skill", result.Skills[0].Name) +} + +func TestRegistryHTTPError_Unwrap(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + wantErrIs error + }{ + {name: "401 wraps unauthorized", statusCode: http.StatusUnauthorized, wantErrIs: ErrRegistryUnauthorized}, + {name: "403 wraps unauthorized", statusCode: http.StatusForbidden, wantErrIs: ErrRegistryUnauthorized}, + {name: "404 unwraps to nil", statusCode: http.StatusNotFound, wantErrIs: nil}, + {name: "500 unwraps to nil", statusCode: http.StatusInternalServerError, wantErrIs: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := &RegistryHTTPError{ + StatusCode: tt.statusCode, + Body: "test body", + URL: "http://example.com/test", + } + require.Equal(t, tt.wantErrIs, err.Unwrap()) + require.Contains(t, err.Error(), fmt.Sprintf("status %d", tt.statusCode)) + }) + } +}