From 97d4682fbcd234c46fa307c3fce4318e7c2931ae Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Tue, 17 Mar 2026 08:31:27 +0200 Subject: [PATCH] Add Skills API client for registry extension The toolhive-registry-server exposes a Skills API as a ToolHive-specific extension under /v0.1/x/dev.toolhive/skills. This adds an HTTP client to query that API, following the same patterns as the existing server client. - Extract shared HTTP client builder and error types into shared.go so both the server client and new skills client reuse the same security controls (private IP policy, auth token injection, error handling with LimitReader) - Add SkillsClient interface with GetSkill, GetSkillVersion, ListSkills, SearchSkills, and ListSkillVersions methods - Add RegistryHTTPError with Unwrap() for structured 401/403 handling - Migrate existing server client to use the shared error type - Add comprehensive table-driven tests with httptest Co-Authored-By: Claude Opus 4.6 (1M context) --- pkg/registry/api/client.go | 21 +- pkg/registry/api/shared.go | 64 +++ pkg/registry/api/skills_client.go | 258 +++++++++++ pkg/registry/api/skills_client_test.go | 565 +++++++++++++++++++++++++ 4 files changed, 892 insertions(+), 16 deletions(-) create mode 100644 pkg/registry/api/shared.go create mode 100644 pkg/registry/api/skills_client.go create mode 100644 pkg/registry/api/skills_client_test.go diff --git a/pkg/registry/api/client.go b/pkg/registry/api/client.go index 4caf9d63d4..910f2a5e11 100644 --- a/pkg/registry/api/client.go +++ b/pkg/registry/api/client.go @@ -16,7 +16,6 @@ import ( v0 "github.com/modelcontextprotocol/registry/pkg/api/v0" "gopkg.in/yaml.v3" - "github.com/stacklok/toolhive/pkg/networking" "github.com/stacklok/toolhive/pkg/registry/auth" "github.com/stacklok/toolhive/pkg/versions" ) @@ -63,18 +62,11 @@ type mcpRegistryClient struct { func NewClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (Client, error) { // Build HTTP client with security controls // If private IPs are allowed, also allow HTTP (for localhost testing) - builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp) - if allowPrivateIp { - builder = builder.WithInsecureAllowHTTP(true) - } - httpClient, err := builder.Build() + httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource) if err != nil { - return nil, fmt.Errorf("failed to build HTTP client: %w", err) + return nil, err } - // Wrap transport with auth if token source is provided - httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource) - // Ensure base URL doesn't have trailing slash if baseURL[len(baseURL)-1] == '/' { baseURL = baseURL[:len(baseURL)-1] @@ -112,8 +104,7 @@ func (c *mcpRegistryClient) GetServer(ctx context.Context, name string) (*v0.Ser }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, newRegistryHTTPError(resp) } var serverResp v0.ServerResponse @@ -207,8 +198,7 @@ func (c *mcpRegistryClient) fetchServersPage( }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, "", fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, "", newRegistryHTTPError(resp) } var listResp v0.ServerListResponse @@ -252,8 +242,7 @@ func (c *mcpRegistryClient) SearchServers(ctx context.Context, query string) ([] }() if resp.StatusCode != http.StatusOK { - body, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + return nil, newRegistryHTTPError(resp) } var listResp v0.ServerListResponse diff --git a/pkg/registry/api/shared.go b/pkg/registry/api/shared.go new file mode 100644 index 0000000000..c0ff6f482b --- /dev/null +++ b/pkg/registry/api/shared.go @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "errors" + "fmt" + "io" + "net/http" + + "github.com/stacklok/toolhive/pkg/networking" + "github.com/stacklok/toolhive/pkg/registry/auth" +) + +const maxErrorBodySize = 4096 + +// ErrRegistryUnauthorized is a sentinel error for 401/403 responses from registry APIs. +var ErrRegistryUnauthorized = errors.New("registry authentication failed") + +// RegistryHTTPError represents an HTTP error from a registry API endpoint. +type RegistryHTTPError struct { + StatusCode int + Body string + URL string +} + +func (e *RegistryHTTPError) Error() string { + return fmt.Sprintf("registry API returned status %d for %s: %s", e.StatusCode, e.URL, e.Body) +} + +// Unwrap returns ErrRegistryUnauthorized for 401/403 status codes, +// allowing callers to use errors.Is(err, ErrRegistryUnauthorized). +func (e *RegistryHTTPError) Unwrap() error { + if e.StatusCode == http.StatusUnauthorized || e.StatusCode == http.StatusForbidden { + return ErrRegistryUnauthorized + } + return nil +} + +// buildHTTPClient creates an HTTP client with security controls and optional auth. +// If allowPrivateIp is true, HTTP (non-HTTPS) is also allowed for localhost testing. +func buildHTTPClient(allowPrivateIp bool, tokenSource auth.TokenSource) (*http.Client, error) { + builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp) + if allowPrivateIp { + builder = builder.WithInsecureAllowHTTP(true) + } + httpClient, err := builder.Build() + if err != nil { + return nil, fmt.Errorf("failed to build HTTP client: %w", err) + } + httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource) + return httpClient, nil +} + +// newRegistryHTTPError reads the response body (limited) and returns a RegistryHTTPError. +func newRegistryHTTPError(resp *http.Response) *RegistryHTTPError { + body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize)) + return &RegistryHTTPError{ + StatusCode: resp.StatusCode, + Body: string(body), + URL: resp.Request.URL.String(), + } +} diff --git a/pkg/registry/api/skills_client.go b/pkg/registry/api/skills_client.go new file mode 100644 index 0000000000..642259dbae --- /dev/null +++ b/pkg/registry/api/skills_client.go @@ -0,0 +1,258 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" + "github.com/stacklok/toolhive/pkg/registry/auth" + "github.com/stacklok/toolhive/pkg/versions" +) + +const skillsBasePath = "/v0.1/x/dev.toolhive/skills" + +// SkillsListOptions contains options for listing skills. +type SkillsListOptions struct { + // Search is an optional search query to filter skills. + Search string + // Limit is the maximum number of skills per page (default: 100). + Limit int + // Cursor is the pagination cursor for fetching the next page. + Cursor string +} + +// SkillsListResult contains a page of skills and pagination info. +type SkillsListResult struct { + Skills []*thvregistry.Skill + NextCursor string +} + +// SkillsClient provides access to the ToolHive Skills extension API. +type SkillsClient interface { + // GetSkill retrieves a skill by namespace and name (latest version). + GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) + // GetSkillVersion retrieves a specific version of a skill. + GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) + // ListSkills retrieves skills with optional filtering and pagination. + ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) + // SearchSkills searches for skills matching the query (single page, no auto-pagination). + SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) + // ListSkillVersions lists all versions of a specific skill. + ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) +} + +// NewSkillsClient creates a new ToolHive Skills extension API client. +// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject +// Bearer tokens into all requests. +func NewSkillsClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (SkillsClient, error) { + httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource) + if err != nil { + return nil, err + } + + // Ensure base URL doesn't have trailing slash + baseURL = strings.TrimRight(baseURL, "/") + + return &mcpSkillsClient{ + baseURL: baseURL, + httpClient: httpClient, + userAgent: versions.GetUserAgent(), + }, nil +} + +// GetSkill retrieves a skill by namespace and name (latest version). +func (c *mcpSkillsClient) GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name)) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var skill thvregistry.Skill + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { + return nil, err + } + return &skill, nil +} + +// GetSkillVersion retrieves a specific version of a skill. +func (c *mcpSkillsClient) GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, + url.PathEscape(namespace), url.PathEscape(name), + "versions", url.PathEscape(version)) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var skill thvregistry.Skill + if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil { + return nil, err + } + return &skill, nil +} + +// ListSkills retrieves skills with optional filtering and pagination. +// It auto-paginates through all available pages, concatenating results. +func (c *mcpSkillsClient) ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) { + if opts == nil { + opts = &SkillsListOptions{} + } + if opts.Limit == 0 { + opts.Limit = 100 + } + + var allSkills []*thvregistry.Skill + cursor := opts.Cursor + + // Pagination loop - continue until no more cursors + for { + page, nextCursor, err := c.fetchSkillsPage(ctx, cursor, opts) + if err != nil { + return nil, err + } + + allSkills = append(allSkills, page...) + + // Check if we have more pages + if nextCursor == "" { + break + } + + cursor = nextCursor + + // Safety limit: prevent infinite loops + if len(allSkills) > 10000 { + return nil, fmt.Errorf("exceeded maximum skills limit (10000)") + } + } + + return &SkillsListResult{ + Skills: allSkills, + }, nil +} + +// SearchSkills searches for skills matching the query. +// Returns a single page of results (no auto-pagination). +func (c *mcpSkillsClient) SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) { + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + params := url.Values{} + params.Add("search", query) + + endpoint := basePath + "?" + params.Encode() + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, err + } + + return &SkillsListResult{ + Skills: listResp.Skills, + NextCursor: listResp.Metadata.NextCursor, + }, nil +} + +// ListSkillVersions lists all versions of a specific skill. +func (c *mcpSkillsClient) ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) { + endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name), "versions") + if err != nil { + return nil, fmt.Errorf("failed to build skills URL: %w", err) + } + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, err + } + + return &SkillsListResult{ + Skills: listResp.Skills, + NextCursor: listResp.Metadata.NextCursor, + }, nil +} + +// mcpSkillsClient implements the SkillsClient interface. +type mcpSkillsClient struct { + baseURL string + httpClient *http.Client + userAgent string +} + +// skillsListResponse is the wire format for list/search responses. +type skillsListResponse struct { + Skills []*thvregistry.Skill `json:"skills"` + Metadata struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + } `json:"metadata"` +} + +// doSkillsGet performs an HTTP GET request and decodes the JSON response into dest. +func (c *mcpSkillsClient) doSkillsGet(ctx context.Context, endpoint string, dest any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("User-Agent", c.userAgent) + + resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL from configured registry + if err != nil { + return fmt.Errorf("failed to execute request: %w", err) + } + defer func() { + if err := resp.Body.Close(); err != nil { + slog.Debug("failed to close response body", "error", err) + } + }() + + if resp.StatusCode != http.StatusOK { + return newRegistryHTTPError(resp) + } + + if err := json.NewDecoder(resp.Body).Decode(dest); err != nil { + return fmt.Errorf("failed to decode response: %w", err) + } + return nil +} + +// fetchSkillsPage fetches a single page of skills. +func (c *mcpSkillsClient) fetchSkillsPage( + ctx context.Context, cursor string, opts *SkillsListOptions, +) ([]*thvregistry.Skill, string, error) { + params := url.Values{} + if cursor != "" { + params.Add("cursor", cursor) + } + if opts.Limit > 0 { + params.Add("limit", fmt.Sprintf("%d", opts.Limit)) + } + if opts.Search != "" { + params.Add("search", opts.Search) + } + + basePath, err := url.JoinPath(c.baseURL, skillsBasePath) + if err != nil { + return nil, "", fmt.Errorf("failed to build skills URL: %w", err) + } + endpoint := func() string { + if len(params) > 0 { + return basePath + "?" + params.Encode() + } + return basePath + }() + + var listResp skillsListResponse + if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil { + return nil, "", err + } + + return listResp.Skills, listResp.Metadata.NextCursor, nil +} diff --git a/pkg/registry/api/skills_client_test.go b/pkg/registry/api/skills_client_test.go new file mode 100644 index 0000000000..880f28a6d1 --- /dev/null +++ b/pkg/registry/api/skills_client_test.go @@ -0,0 +1,565 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package api + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + + thvregistry "github.com/stacklok/toolhive-core/registry/types" +) + +func newTestSkillsClient(t *testing.T, server *httptest.Server) SkillsClient { + t.Helper() + client, err := NewSkillsClient(server.URL, true, nil) + require.NoError(t, err) + return client +} + +func TestSkillsClient_GetSkill(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "success", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + Description: "A test skill", + }, + }, + { + name: "not found", + namespace: "io.github.user", + skillName: "nonexistent", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("skill not found")) + }, + wantErr: true, + }, + { + name: "server error", + namespace: "io.github.user", + skillName: "my-skill", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("internal error")) + }, + wantErr: true, + }, + { + name: "path escaping", + namespace: "io.github.user/special", + skillName: "my skill", + handler: func(w http.ResponseWriter, r *http.Request) { + // Verify that the path components are properly escaped + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user%2Fspecial/my%20skill", r.URL.RawPath) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user/special", + Name: "my skill", + Version: "1.0.0", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user/special", + Name: "my skill", + Version: "1.0.0", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + skill, err := client.GetSkill(t.Context(), tt.namespace, tt.skillName) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestSkillsClient_GetSkillVersion(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + namespace string + skillName string + version string + handler http.HandlerFunc + wantSkill *thvregistry.Skill + wantErr bool + }{ + { + name: "success", + namespace: "io.github.user", + skillName: "my-skill", + version: "2.0.0", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill/versions/2.0.0", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "2.0.0", + Description: "Version 2", + }) + require.NoError(t, err) + }, + wantSkill: &thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "2.0.0", + Description: "Version 2", + }, + }, + { + name: "version not found", + namespace: "io.github.user", + skillName: "my-skill", + version: "99.0.0", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("version not found")) + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + skill, err := client.GetSkillVersion(t.Context(), tt.namespace, tt.skillName, tt.version) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantSkill, skill) + }) + } +} + +func TestSkillsClient_ListSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts *SkillsListOptions + handler http.HandlerFunc + wantCount int + wantErr bool + wantSkills []*thvregistry.Skill + }{ + { + name: "single page", + opts: nil, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 2, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 2, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + }, + { + name: "pagination across multiple pages", + opts: &SkillsListOptions{Limit: 1}, + handler: func() http.HandlerFunc { + callCount := 0 + return func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + + cursor := r.URL.Query().Get("cursor") + var resp skillsListResponse + + switch { + case cursor == "" && callCount == 1: + resp = skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: "page2"}, + } + case cursor == "page2": + resp = skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + } + default: + w.WriteHeader(http.StatusBadRequest) + return + } + + err := json.NewEncoder(w).Encode(resp) + require.NoError(t, err) + } + }(), + wantCount: 2, + wantSkills: []*thvregistry.Skill{ + {Namespace: "io.github.a", Name: "skill-1", Version: "1.0.0"}, + {Namespace: "io.github.b", Name: "skill-2", Version: "1.0.0"}, + }, + }, + { + name: "empty result", + opts: nil, + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 0, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 0, + wantSkills: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkills(t.Context(), tt.opts) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, result.Skills, tt.wantCount) + if tt.wantSkills != nil { + require.Equal(t, tt.wantSkills, result.Skills) + } + }) + } +} + +func TestSkillsClient_SearchSkills(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + query string + handler http.HandlerFunc + wantCount int + wantErr bool + }{ + { + name: "success with results", + query: "kubernetes", + handler: func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "kubernetes", r.URL.Query().Get("search")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "k8s-skill", Version: "1.0.0", Description: "Kubernetes skill"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 1, + }, + { + name: "empty result", + query: "nonexistent", + handler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{}, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 0, NextCursor: ""}, + }) + require.NoError(t, err) + }, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(tt.handler) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.SearchSkills(t.Context(), tt.query) + + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Len(t, result.Skills, tt.wantCount) + }) + } +} + +func TestSkillsClient_ListSkillVersions(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v0.1/x/dev.toolhive/skills/io.github.user/my-skill/versions", r.URL.Path) + require.Equal(t, http.MethodGet, r.Method) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "my-skill", Version: "1.0.0"}, + {Namespace: "io.github.user", Name: "my-skill", Version: "2.0.0"}, + {Namespace: "io.github.user", Name: "my-skill", Version: "3.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 3, NextCursor: ""}, + }) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkillVersions(t.Context(), "io.github.user", "my-skill") + require.NoError(t, err) + require.Len(t, result.Skills, 3) + require.Equal(t, "1.0.0", result.Skills[0].Version) + require.Equal(t, "2.0.0", result.Skills[1].Version) + require.Equal(t, "3.0.0", result.Skills[2].Version) +} + +func TestSkillsClient_ErrorHandling(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + body string + wantErrIs error + }{ + { + name: "401 unauthorized", + statusCode: http.StatusUnauthorized, + body: "unauthorized", + wantErrIs: ErrRegistryUnauthorized, + }, + { + name: "403 forbidden", + statusCode: http.StatusForbidden, + body: "forbidden", + wantErrIs: ErrRegistryUnauthorized, + }, + { + name: "500 server error does not unwrap to unauthorized", + statusCode: http.StatusInternalServerError, + body: "internal server error", + wantErrIs: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = w.Write([]byte(tt.body)) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + _, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.Error(t, err) + + var httpErr *RegistryHTTPError + require.True(t, errors.As(err, &httpErr), "expected *RegistryHTTPError, got %T", err) + require.Equal(t, tt.statusCode, httpErr.StatusCode) + require.Contains(t, httpErr.Body, tt.body) + + if tt.wantErrIs != nil { + require.True(t, errors.Is(err, tt.wantErrIs), + "expected errors.Is(%v, %v) to be true", err, tt.wantErrIs) + } else { + require.False(t, errors.Is(err, ErrRegistryUnauthorized), + "expected errors.Is(%v, ErrRegistryUnauthorized) to be false", err) + } + }) + } +} + +func TestSkillsClient_MalformedJSON(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + _, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.Error(t, err) + require.Contains(t, err.Error(), "failed to decode response") +} + +func TestSkillsClient_TrailingSlashInBaseURL(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The path should not have a double slash + require.NotContains(t, r.URL.Path, "//") + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(thvregistry.Skill{ + Namespace: "io.github.user", + Name: "my-skill", + Version: "1.0.0", + }) + require.NoError(t, err) + })) + defer server.Close() + + // Create client with trailing slash + client, err := NewSkillsClient(server.URL+"/", true, nil) + require.NoError(t, err) + + skill, err := client.GetSkill(t.Context(), "io.github.user", "my-skill") + require.NoError(t, err) + require.Equal(t, "io.github.user", skill.Namespace) +} + +func TestSkillsClient_ListSkillsWithSearch(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "test-query", r.URL.Query().Get("search")) + require.Equal(t, "50", r.URL.Query().Get("limit")) + w.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(w).Encode(skillsListResponse{ + Skills: []*thvregistry.Skill{ + {Namespace: "io.github.user", Name: "test-skill", Version: "1.0.0"}, + }, + Metadata: struct { + Count int `json:"count"` + NextCursor string `json:"nextCursor"` + }{Count: 1, NextCursor: ""}, + }) + require.NoError(t, err) + })) + defer server.Close() + + client := newTestSkillsClient(t, server) + result, err := client.ListSkills(t.Context(), &SkillsListOptions{ + Search: "test-query", + Limit: 50, + }) + require.NoError(t, err) + require.Len(t, result.Skills, 1) + require.Equal(t, "test-skill", result.Skills[0].Name) +} + +func TestRegistryHTTPError_Unwrap(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + wantErrIs error + }{ + {name: "401 wraps unauthorized", statusCode: http.StatusUnauthorized, wantErrIs: ErrRegistryUnauthorized}, + {name: "403 wraps unauthorized", statusCode: http.StatusForbidden, wantErrIs: ErrRegistryUnauthorized}, + {name: "404 unwraps to nil", statusCode: http.StatusNotFound, wantErrIs: nil}, + {name: "500 unwraps to nil", statusCode: http.StatusInternalServerError, wantErrIs: nil}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := &RegistryHTTPError{ + StatusCode: tt.statusCode, + Body: "test body", + URL: "http://example.com/test", + } + require.Equal(t, tt.wantErrIs, err.Unwrap()) + require.Contains(t, err.Error(), fmt.Sprintf("status %d", tt.statusCode)) + }) + } +}