Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 5 additions & 16 deletions pkg/registry/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
v0 "github.com/modelcontextprotocol/registry/pkg/api/v0"
"gopkg.in/yaml.v3"

"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/registry/auth"
"github.com/stacklok/toolhive/pkg/versions"
)
Expand Down Expand Up @@ -63,18 +62,11 @@ type mcpRegistryClient struct {
func NewClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (Client, error) {
// Build HTTP client with security controls
// If private IPs are allowed, also allow HTTP (for localhost testing)
builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp)
if allowPrivateIp {
builder = builder.WithInsecureAllowHTTP(true)
}
httpClient, err := builder.Build()
httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource)
if err != nil {
return nil, fmt.Errorf("failed to build HTTP client: %w", err)
return nil, err
}

// Wrap transport with auth if token source is provided
httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource)

// Ensure base URL doesn't have trailing slash
if baseURL[len(baseURL)-1] == '/' {
baseURL = baseURL[:len(baseURL)-1]
Expand Down Expand Up @@ -112,8 +104,7 @@ func (c *mcpRegistryClient) GetServer(ctx context.Context, name string) (*v0.Ser
}()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, newRegistryHTTPError(resp)
}

var serverResp v0.ServerResponse
Expand Down Expand Up @@ -207,8 +198,7 @@ func (c *mcpRegistryClient) fetchServersPage(
}()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, "", fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, "", newRegistryHTTPError(resp)
}

var listResp v0.ServerListResponse
Expand Down Expand Up @@ -252,8 +242,7 @@ func (c *mcpRegistryClient) SearchServers(ctx context.Context, query string) ([]
}()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
return nil, newRegistryHTTPError(resp)
}

var listResp v0.ServerListResponse
Expand Down
64 changes: 64 additions & 0 deletions pkg/registry/api/shared.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package api

import (
"errors"
"fmt"
"io"
"net/http"

"github.com/stacklok/toolhive/pkg/networking"
"github.com/stacklok/toolhive/pkg/registry/auth"
)

const maxErrorBodySize = 4096

// ErrRegistryUnauthorized is a sentinel error for 401/403 responses from registry APIs.
var ErrRegistryUnauthorized = errors.New("registry authentication failed")

// RegistryHTTPError represents an HTTP error from a registry API endpoint.
type RegistryHTTPError struct {
StatusCode int
Body string
URL string
}

func (e *RegistryHTTPError) Error() string {
return fmt.Sprintf("registry API returned status %d for %s: %s", e.StatusCode, e.URL, e.Body)
}

// Unwrap returns ErrRegistryUnauthorized for 401/403 status codes,
// allowing callers to use errors.Is(err, ErrRegistryUnauthorized).
func (e *RegistryHTTPError) Unwrap() error {
if e.StatusCode == http.StatusUnauthorized || e.StatusCode == http.StatusForbidden {
return ErrRegistryUnauthorized
}
return nil
}

// buildHTTPClient creates an HTTP client with security controls and optional auth.
// If allowPrivateIp is true, HTTP (non-HTTPS) is also allowed for localhost testing.
func buildHTTPClient(allowPrivateIp bool, tokenSource auth.TokenSource) (*http.Client, error) {
builder := networking.NewHttpClientBuilder().WithPrivateIPs(allowPrivateIp)
if allowPrivateIp {
builder = builder.WithInsecureAllowHTTP(true)
}
httpClient, err := builder.Build()
if err != nil {
return nil, fmt.Errorf("failed to build HTTP client: %w", err)
}
httpClient.Transport = auth.WrapTransport(httpClient.Transport, tokenSource)
return httpClient, nil
}

// newRegistryHTTPError reads the response body (limited) and returns a RegistryHTTPError.
func newRegistryHTTPError(resp *http.Response) *RegistryHTTPError {
body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodySize))
return &RegistryHTTPError{
StatusCode: resp.StatusCode,
Body: string(body),
URL: resp.Request.URL.String(),
}
}
258 changes: 258 additions & 0 deletions pkg/registry/api/skills_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package api

import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"net/url"
"strings"

thvregistry "github.com/stacklok/toolhive-core/registry/types"
"github.com/stacklok/toolhive/pkg/registry/auth"
"github.com/stacklok/toolhive/pkg/versions"
)

const skillsBasePath = "/v0.1/x/dev.toolhive/skills"

// SkillsListOptions contains options for listing skills.
type SkillsListOptions struct {
// Search is an optional search query to filter skills.
Search string
// Limit is the maximum number of skills per page (default: 100).
Limit int
// Cursor is the pagination cursor for fetching the next page.
Cursor string
}

// SkillsListResult contains a page of skills and pagination info.
type SkillsListResult struct {
Skills []*thvregistry.Skill
NextCursor string
}

// SkillsClient provides access to the ToolHive Skills extension API.
type SkillsClient interface {
// GetSkill retrieves a skill by namespace and name (latest version).
GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error)
// GetSkillVersion retrieves a specific version of a skill.
GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error)
// ListSkills retrieves skills with optional filtering and pagination.
ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error)
// SearchSkills searches for skills matching the query (single page, no auto-pagination).
SearchSkills(ctx context.Context, query string) (*SkillsListResult, error)
// ListSkillVersions lists all versions of a specific skill.
ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error)
}

// NewSkillsClient creates a new ToolHive Skills extension API client.
// If tokenSource is non-nil, the HTTP client transport will be wrapped to inject
// Bearer tokens into all requests.
func NewSkillsClient(baseURL string, allowPrivateIp bool, tokenSource auth.TokenSource) (SkillsClient, error) {
httpClient, err := buildHTTPClient(allowPrivateIp, tokenSource)
if err != nil {
return nil, err
}

// Ensure base URL doesn't have trailing slash
baseURL = strings.TrimRight(baseURL, "/")

return &mcpSkillsClient{
baseURL: baseURL,
httpClient: httpClient,
userAgent: versions.GetUserAgent(),
}, nil
}

// GetSkill retrieves a skill by namespace and name (latest version).
func (c *mcpSkillsClient) GetSkill(ctx context.Context, namespace, name string) (*thvregistry.Skill, error) {
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name))
if err != nil {
return nil, fmt.Errorf("failed to build skills URL: %w", err)
}

var skill thvregistry.Skill
if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil {
return nil, err
}
return &skill, nil
}

// GetSkillVersion retrieves a specific version of a skill.
func (c *mcpSkillsClient) GetSkillVersion(ctx context.Context, namespace, name, version string) (*thvregistry.Skill, error) {
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath,
url.PathEscape(namespace), url.PathEscape(name),
"versions", url.PathEscape(version))
if err != nil {
return nil, fmt.Errorf("failed to build skills URL: %w", err)
}

var skill thvregistry.Skill
if err := c.doSkillsGet(ctx, endpoint, &skill); err != nil {
return nil, err
}
return &skill, nil
}

// ListSkills retrieves skills with optional filtering and pagination.
// It auto-paginates through all available pages, concatenating results.
func (c *mcpSkillsClient) ListSkills(ctx context.Context, opts *SkillsListOptions) (*SkillsListResult, error) {
if opts == nil {
opts = &SkillsListOptions{}
}
if opts.Limit == 0 {
opts.Limit = 100
}

var allSkills []*thvregistry.Skill
cursor := opts.Cursor

// Pagination loop - continue until no more cursors
for {
page, nextCursor, err := c.fetchSkillsPage(ctx, cursor, opts)
if err != nil {
return nil, err
}

allSkills = append(allSkills, page...)

// Check if we have more pages
if nextCursor == "" {
break
}

cursor = nextCursor

// Safety limit: prevent infinite loops
if len(allSkills) > 10000 {
return nil, fmt.Errorf("exceeded maximum skills limit (10000)")
}
}

return &SkillsListResult{
Skills: allSkills,
}, nil
}

// SearchSkills searches for skills matching the query.
// Returns a single page of results (no auto-pagination).
func (c *mcpSkillsClient) SearchSkills(ctx context.Context, query string) (*SkillsListResult, error) {
basePath, err := url.JoinPath(c.baseURL, skillsBasePath)
if err != nil {
return nil, fmt.Errorf("failed to build skills URL: %w", err)
}
params := url.Values{}
params.Add("search", query)

endpoint := basePath + "?" + params.Encode()

var listResp skillsListResponse
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
return nil, err
}

return &SkillsListResult{
Skills: listResp.Skills,
NextCursor: listResp.Metadata.NextCursor,
}, nil
}

// ListSkillVersions lists all versions of a specific skill.
func (c *mcpSkillsClient) ListSkillVersions(ctx context.Context, namespace, name string) (*SkillsListResult, error) {
endpoint, err := url.JoinPath(c.baseURL, skillsBasePath, url.PathEscape(namespace), url.PathEscape(name), "versions")
if err != nil {
return nil, fmt.Errorf("failed to build skills URL: %w", err)
}

var listResp skillsListResponse
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
return nil, err
}

return &SkillsListResult{
Skills: listResp.Skills,
NextCursor: listResp.Metadata.NextCursor,
}, nil
}

// mcpSkillsClient implements the SkillsClient interface.
type mcpSkillsClient struct {
baseURL string
httpClient *http.Client
userAgent string
}

// skillsListResponse is the wire format for list/search responses.
type skillsListResponse struct {
Skills []*thvregistry.Skill `json:"skills"`
Metadata struct {
Count int `json:"count"`
NextCursor string `json:"nextCursor"`
} `json:"metadata"`
}

// doSkillsGet performs an HTTP GET request and decodes the JSON response into dest.
func (c *mcpSkillsClient) doSkillsGet(ctx context.Context, endpoint string, dest any) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("User-Agent", c.userAgent)

resp, err := c.httpClient.Do(req) //nolint:gosec // G704: URL from configured registry
if err != nil {
return fmt.Errorf("failed to execute request: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
slog.Debug("failed to close response body", "error", err)
}
}()

if resp.StatusCode != http.StatusOK {
return newRegistryHTTPError(resp)
}

if err := json.NewDecoder(resp.Body).Decode(dest); err != nil {
return fmt.Errorf("failed to decode response: %w", err)
}
return nil
}

// fetchSkillsPage fetches a single page of skills.
func (c *mcpSkillsClient) fetchSkillsPage(
ctx context.Context, cursor string, opts *SkillsListOptions,
) ([]*thvregistry.Skill, string, error) {
params := url.Values{}
if cursor != "" {
params.Add("cursor", cursor)
}
if opts.Limit > 0 {
params.Add("limit", fmt.Sprintf("%d", opts.Limit))
}
if opts.Search != "" {
params.Add("search", opts.Search)
}

basePath, err := url.JoinPath(c.baseURL, skillsBasePath)
if err != nil {
return nil, "", fmt.Errorf("failed to build skills URL: %w", err)
}
endpoint := func() string {
if len(params) > 0 {
return basePath + "?" + params.Encode()
}
return basePath
}()

var listResp skillsListResponse
if err := c.doSkillsGet(ctx, endpoint, &listResp); err != nil {
return nil, "", err
}

return listResp.Skills, listResp.Metadata.NextCursor, nil
}
Loading
Loading