From 49e9d1682403e79d4ef355966801e793b2482132 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Tue, 17 Feb 2026 00:55:13 +0530 Subject: [PATCH 1/3] ref: Implementation for Dynamic webhook phase one Signed-off-by: Sanskarzz --- pkg/webhook/client.go | 217 ++++++++++++++++ pkg/webhook/client_test.go | 494 ++++++++++++++++++++++++++++++++++++ pkg/webhook/errors.go | 86 +++++++ pkg/webhook/errors_test.go | 85 +++++++ pkg/webhook/signing.go | 56 ++++ pkg/webhook/signing_test.go | 159 ++++++++++++ pkg/webhook/types.go | 183 +++++++++++++ pkg/webhook/types_test.go | 164 ++++++++++++ 8 files changed, 1444 insertions(+) create mode 100644 pkg/webhook/client.go create mode 100644 pkg/webhook/client_test.go create mode 100644 pkg/webhook/errors.go create mode 100644 pkg/webhook/errors_test.go create mode 100644 pkg/webhook/signing.go create mode 100644 pkg/webhook/signing_test.go create mode 100644 pkg/webhook/types.go create mode 100644 pkg/webhook/types_test.go diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go new file mode 100644 index 0000000000..13cd6d3c50 --- /dev/null +++ b/pkg/webhook/client.go @@ -0,0 +1,217 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "strconv" + "time" +) + +// Client is an HTTP client for calling webhook endpoints. +type Client struct { + httpClient *http.Client + config Config + hmacSecret []byte + webhookType Type +} + +// NewClient creates a new webhook Client from the given configuration. +// The hmacSecret parameter is the resolved secret bytes for HMAC signing; +// pass nil if signing is not configured. +func NewClient(cfg Config, webhookType Type, hmacSecret []byte) (*Client, error) { + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("invalid webhook config: %w", err) + } + + timeout := cfg.Timeout + if timeout == 0 { + timeout = DefaultTimeout + } + + transport, err := buildTransport(cfg.TLSConfig) + if err != nil { + return nil, fmt.Errorf("failed to build HTTP transport: %w", err) + } + + return &Client{ + httpClient: &http.Client{ + Transport: transport, + Timeout: timeout, + }, + config: cfg, + hmacSecret: hmacSecret, + webhookType: webhookType, + }, nil +} + +// Call sends a request to a validating webhook and returns its response. +func (c *Client) Call(ctx context.Context, req *Request) (*Response, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, NewInvalidResponseError(c.config.Name, fmt.Errorf("failed to marshal request: %w", err)) + } + + respBody, err := c.doHTTPCall(ctx, body) + if err != nil { + return nil, err + } + + var resp Response + if err := json.Unmarshal(respBody, &resp); err != nil { + return nil, NewInvalidResponseError(c.config.Name, fmt.Errorf("failed to unmarshal response: %w", err)) + } + + return &resp, nil +} + +// CallMutating sends a request to a mutating webhook and returns its response. +func (c *Client) CallMutating(ctx context.Context, req *Request) (*MutatingResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, NewInvalidResponseError(c.config.Name, fmt.Errorf("failed to marshal request: %w", err)) + } + + respBody, err := c.doHTTPCall(ctx, body) + if err != nil { + return nil, err + } + + var resp MutatingResponse + if err := json.Unmarshal(respBody, &resp); err != nil { + return nil, NewInvalidResponseError(c.config.Name, fmt.Errorf("failed to unmarshal mutating response: %w", err)) + } + + return &resp, nil +} + +// doHTTPCall performs the HTTP POST to the webhook endpoint, handling signing, +// error classification, and response size limiting. +func (c *Client) doHTTPCall(ctx context.Context, body []byte) ([]byte, error) { + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.config.URL, bytes.NewReader(body)) + if err != nil { + return nil, NewNetworkError(c.config.Name, fmt.Errorf("failed to create HTTP request: %w", err)) + } + httpReq.Header.Set("Content-Type", "application/json") + + // Apply HMAC signing if configured. + if len(c.hmacSecret) > 0 { + timestamp := time.Now().Unix() + signature := SignPayload(c.hmacSecret, timestamp, body) + httpReq.Header.Set(SignatureHeader, signature) + httpReq.Header.Set(TimestampHeader, strconv.FormatInt(timestamp, 10)) + } + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, classifyError(c.config.Name, err) + } + defer func() { + _ = resp.Body.Close() + }() + + // Enforce response size limit. + limitedReader := io.LimitReader(resp.Body, MaxResponseSize+1) + respBody, err := io.ReadAll(limitedReader) + if err != nil { + return nil, NewNetworkError(c.config.Name, fmt.Errorf("failed to read response body: %w", err)) + } + if int64(len(respBody)) > MaxResponseSize { + return nil, NewInvalidResponseError(c.config.Name, + fmt.Errorf("response body exceeds maximum size of %d bytes", MaxResponseSize)) + } + + // 5xx errors indicate webhook operational failures. + if resp.StatusCode >= http.StatusInternalServerError { + return nil, NewNetworkError(c.config.Name, + fmt.Errorf("webhook returned HTTP %d: %s", resp.StatusCode, truncateBody(respBody))) + } + + // Non-200 responses (excluding 5xx handled above) are treated as invalid. + if resp.StatusCode != http.StatusOK { + return nil, NewInvalidResponseError(c.config.Name, + fmt.Errorf("webhook returned HTTP %d: %s", resp.StatusCode, truncateBody(respBody))) + } + + return respBody, nil +} + +// buildTransport creates an http.Transport with the specified TLS configuration. +func buildTransport(tlsCfg *TLSConfig) (*http.Transport, error) { + transport := &http.Transport{ + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 10 * time.Second, + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + } + + if tlsCfg == nil { + return transport, nil + } + + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + } + + // Load CA bundle if provided. + if tlsCfg.CABundlePath != "" { + caCert, err := os.ReadFile(tlsCfg.CABundlePath) + if err != nil { + return nil, fmt.Errorf("failed to read CA bundle: %w", err) + } + caCertPool := x509.NewCertPool() + if !caCertPool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("failed to parse CA certificate bundle") + } + tlsConfig.RootCAs = caCertPool + } + + // Load client certificate for mTLS if provided. + if tlsCfg.ClientCertPath != "" && tlsCfg.ClientKeyPath != "" { + cert, err := tls.LoadX509KeyPair(tlsCfg.ClientCertPath, tlsCfg.ClientKeyPath) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + if tlsCfg.InsecureSkipVerify { + //#nosec G402 -- InsecureSkipVerify is intentionally user-configurable for development/testing only. + tlsConfig.InsecureSkipVerify = true + } + + transport.TLSClientConfig = tlsConfig + return transport, nil +} + +// classifyError examines an HTTP client error and returns an appropriately +// typed webhook error (TimeoutError or NetworkError). +func classifyError(webhookName string, err error) error { + // Check for timeout errors (context deadline, net.Error timeout). + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return NewTimeoutError(webhookName, err) + } + return NewNetworkError(webhookName, err) +} + +// truncateBody returns a preview of the response body for error messages. +func truncateBody(body []byte) string { + const maxPreview = 256 + if len(body) <= maxPreview { + return string(body) + } + return string(body[:maxPreview]) + "..." +} diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go new file mode 100644 index 0000000000..50cad2bbe7 --- /dev/null +++ b/pkg/webhook/client_test.go @@ -0,0 +1,494 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewClient(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + expectError bool + }{ + { + name: "valid config", + config: Config{ + Name: "test", + URL: "https://example.com/webhook", + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + }, + expectError: false, + }, + { + name: "valid config with zero timeout", + config: Config{ + Name: "test", + URL: "https://example.com/webhook", + Timeout: 0, + FailurePolicy: FailurePolicyIgnore, + }, + expectError: false, + }, + { + name: "invalid config", + config: Config{ + Name: "", + URL: "", + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + client, err := NewClient(tt.config, TypeValidating, nil) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, client) + } else { + assert.NoError(t, err) + assert.NotNil(t, client) + } + }) + } +} + +func TestClientCallValidating(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + serverHandler http.HandlerFunc + expectError bool + expectedResult *Response + errorType interface{} + }{ + { + name: "allowed response", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + resp := Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: true, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }, + expectedResult: &Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: true, + }, + }, + { + name: "denied response", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + resp := Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: false, + Code: 403, + Message: "Access denied", + Reason: "PolicyDenied", + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + }, + expectedResult: &Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: false, + Code: 403, + Message: "Access denied", + Reason: "PolicyDenied", + }, + }, + { + name: "server 500 error", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("internal server error")) + }, + expectError: true, + errorType: &NetworkError{}, + }, + { + name: "server 503 error", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + w.Write([]byte("service unavailable")) + }, + expectError: true, + errorType: &NetworkError{}, + }, + { + name: "invalid JSON response", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte("not json")) + }, + expectError: true, + errorType: &InvalidResponseError{}, + }, + { + name: "non-200 non-5xx response", + serverHandler: func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("bad request")) + }, + expectError: true, + errorType: &InvalidResponseError{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(tt.serverHandler) + defer server.Close() + + cfg := Config{ + Name: "test-webhook", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + + client := newTestClient(cfg, TypeValidating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + Principal: &Principal{Sub: "user1"}, + Context: &RequestContext{ + ServerName: "test-server", + SourceIP: "127.0.0.1", + Transport: "sse", + }, + } + + resp, err := client.Call(context.Background(), req) + + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, resp) + if tt.errorType != nil { + assert.True(t, errors.As(err, &tt.errorType), + "expected error type %T, got %T", tt.errorType, err) + } + } else { + require.NoError(t, err) + require.NotNil(t, resp) + assert.Equal(t, tt.expectedResult.Version, resp.Version) + assert.Equal(t, tt.expectedResult.UID, resp.UID) + assert.Equal(t, tt.expectedResult.Allowed, resp.Allowed) + assert.Equal(t, tt.expectedResult.Code, resp.Code) + assert.Equal(t, tt.expectedResult.Message, resp.Message) + assert.Equal(t, tt.expectedResult.Reason, resp.Reason) + } + }) + } +} + +func TestClientCallMutating(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + resp := MutatingResponse{ + Response: Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: true, + }, + PatchType: "json_patch", + Patch: json.RawMessage(`[{"op":"add","path":"/mcp_request/params/arguments/audit","value":"true"}]`), + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := Config{ + Name: "test-mutating", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyIgnore, + } + + client := newTestClient(cfg, TypeMutating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + Principal: &Principal{Sub: "user1"}, + Context: &RequestContext{ + ServerName: "test-server", + SourceIP: "127.0.0.1", + Transport: "sse", + }, + } + + resp, err := client.CallMutating(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.True(t, resp.Allowed) + assert.Equal(t, "json_patch", resp.PatchType) + assert.NotEmpty(t, resp.Patch) +} + +func TestClientHMACSigningHeaders(t *testing.T) { + t.Parallel() + + var capturedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header + resp := Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: true, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := Config{ + Name: "test-hmac", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + hmacSecret := []byte("test-secret-key") + + client := newTestClient(cfg, TypeValidating, hmacSecret) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + Principal: &Principal{Sub: "user1"}, + Context: &RequestContext{ + ServerName: "test-server", + SourceIP: "127.0.0.1", + Transport: "sse", + }, + } + + resp, err := client.Call(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify HMAC headers were sent. + assert.NotEmpty(t, capturedHeaders.Get(SignatureHeader), "expected %s header", SignatureHeader) + assert.Contains(t, capturedHeaders.Get(SignatureHeader), "sha256=") + assert.NotEmpty(t, capturedHeaders.Get(TimestampHeader), "expected %s header", TimestampHeader) +} + +func TestClientNoHMACHeadersWithoutSecret(t *testing.T) { + t.Parallel() + + var capturedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedHeaders = r.Header + resp := Response{ + Version: APIVersion, + UID: "test-uid", + Allowed: true, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := Config{ + Name: "test-no-hmac", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + + client := newTestClient(cfg, TypeValidating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + Principal: &Principal{Sub: "user1"}, + Context: &RequestContext{ + ServerName: "test-server", + SourceIP: "127.0.0.1", + Transport: "sse", + }, + } + + resp, err := client.Call(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + // Verify HMAC headers were NOT sent. + assert.Empty(t, capturedHeaders.Get(SignatureHeader)) + assert.Empty(t, capturedHeaders.Get(TimestampHeader)) +} + +func TestClientTimeout(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // Sleep longer than the client timeout. + time.Sleep(3 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + cfg := Config{ + Name: "test-timeout", + URL: server.URL, + Timeout: 100 * time.Millisecond, + FailurePolicy: FailurePolicyFail, + } + + client := newTestClient(cfg, TypeValidating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + } + + _, err := client.Call(context.Background(), req) + require.Error(t, err) + + var timeoutErr *TimeoutError + assert.True(t, errors.As(err, &timeoutErr), "expected TimeoutError, got %T: %v", err, err) +} + +func TestClientResponseSizeLimit(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + // Write more than MaxResponseSize bytes. + largeBody := strings.Repeat("x", MaxResponseSize+100) + w.Write([]byte(largeBody)) + })) + defer server.Close() + + cfg := Config{ + Name: "test-size-limit", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + + client := newTestClient(cfg, TypeValidating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + } + + _, err := client.Call(context.Background(), req) + require.Error(t, err) + + var invalidErr *InvalidResponseError + assert.True(t, errors.As(err, &invalidErr), "expected InvalidResponseError, got %T: %v", err, err) + assert.Contains(t, err.Error(), "exceeds maximum size") +} + +func TestClientRequestContentType(t *testing.T) { + t.Parallel() + + var capturedContentType string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedContentType = r.Header.Get("Content-Type") + // Verify request body is valid JSON. + body, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + var req Request + if err := json.Unmarshal(body, &req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + resp := Response{ + Version: APIVersion, + UID: req.UID, + Allowed: true, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + cfg := Config{ + Name: "test-content-type", + URL: server.URL, + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + + client := newTestClient(cfg, TypeValidating, nil) + + req := &Request{ + Version: APIVersion, + UID: "test-uid", + Timestamp: time.Now(), + Principal: &Principal{Sub: "user1"}, + Context: &RequestContext{ + ServerName: "test-server", + SourceIP: "127.0.0.1", + Transport: "sse", + }, + } + + resp, err := client.Call(context.Background(), req) + require.NoError(t, err) + require.NotNil(t, resp) + + assert.Equal(t, "application/json", capturedContentType) +} + +// newTestClient creates a webhook Client suitable for testing with httptest servers. +// It bypasses URL validation (httptest uses HTTP, not HTTPS). +func newTestClient(cfg Config, webhookType Type, hmacSecret []byte) *Client { + timeout := cfg.Timeout + if timeout == 0 { + timeout = DefaultTimeout + } + + return &Client{ + httpClient: &http.Client{ + Timeout: timeout, + }, + config: cfg, + hmacSecret: hmacSecret, + webhookType: webhookType, + } +} diff --git a/pkg/webhook/errors.go b/pkg/webhook/errors.go new file mode 100644 index 0000000000..6424f49a94 --- /dev/null +++ b/pkg/webhook/errors.go @@ -0,0 +1,86 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import "fmt" + +// WebhookError is the base error type for all webhook-related errors. +// +//nolint:revive // WebhookError is the canonical name; renaming to Error conflicts with Error() method. +type WebhookError struct { + // WebhookName is the name of the webhook that caused the error. + WebhookName string + // Err is the underlying error. + Err error +} + +// Error implements the error interface. +func (e *WebhookError) Error() string { + return fmt.Sprintf("webhook %q: %v", e.WebhookName, e.Err) +} + +// Unwrap returns the underlying error for errors.Is/errors.As support. +func (e *WebhookError) Unwrap() error { + return e.Err +} + +// TimeoutError indicates that a webhook call timed out. +type TimeoutError struct { + WebhookError +} + +// Error implements the error interface. +func (e *TimeoutError) Error() string { + return fmt.Sprintf("webhook %q: timeout: %v", e.WebhookName, e.Err) +} + +// NetworkError indicates a network-level failure when calling a webhook. +type NetworkError struct { + WebhookError +} + +// Error implements the error interface. +func (e *NetworkError) Error() string { + return fmt.Sprintf("webhook %q: network error: %v", e.WebhookName, e.Err) +} + +// InvalidResponseError indicates that a webhook returned an unparseable or invalid response. +type InvalidResponseError struct { + WebhookError +} + +// Error implements the error interface. +func (e *InvalidResponseError) Error() string { + return fmt.Sprintf("webhook %q: invalid response: %v", e.WebhookName, e.Err) +} + +// NewTimeoutError creates a new TimeoutError. +func NewTimeoutError(webhookName string, err error) *TimeoutError { + return &TimeoutError{ + WebhookError: WebhookError{ + WebhookName: webhookName, + Err: err, + }, + } +} + +// NewNetworkError creates a new NetworkError. +func NewNetworkError(webhookName string, err error) *NetworkError { + return &NetworkError{ + WebhookError: WebhookError{ + WebhookName: webhookName, + Err: err, + }, + } +} + +// NewInvalidResponseError creates a new InvalidResponseError. +func NewInvalidResponseError(webhookName string, err error) *InvalidResponseError { + return &InvalidResponseError{ + WebhookError: WebhookError{ + WebhookName: webhookName, + Err: err, + }, + } +} diff --git a/pkg/webhook/errors_test.go b/pkg/webhook/errors_test.go new file mode 100644 index 0000000000..a94feaca3e --- /dev/null +++ b/pkg/webhook/errors_test.go @@ -0,0 +1,85 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestWebhookErrors(t *testing.T) { + t.Parallel() + + underlyingErr := fmt.Errorf("connection refused") + + tests := []struct { + name string + err error + expectedMsg string + isTimeout bool + isNetwork bool + isInvalidResp bool + unwrapsToInner bool + }{ + { + name: "TimeoutError", + err: NewTimeoutError("my-webhook", underlyingErr), + expectedMsg: `webhook "my-webhook": timeout: connection refused`, + isTimeout: true, + unwrapsToInner: true, + }, + { + name: "NetworkError", + err: NewNetworkError("my-webhook", underlyingErr), + expectedMsg: `webhook "my-webhook": network error: connection refused`, + isNetwork: true, + unwrapsToInner: true, + }, + { + name: "InvalidResponseError", + err: NewInvalidResponseError("my-webhook", underlyingErr), + expectedMsg: `webhook "my-webhook": invalid response: connection refused`, + isInvalidResp: true, + unwrapsToInner: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + assert.Equal(t, tt.expectedMsg, tt.err.Error()) + + // Test errors.As for each type. + var timeoutErr *TimeoutError + assert.Equal(t, tt.isTimeout, errors.As(tt.err, &timeoutErr)) + + var networkErr *NetworkError + assert.Equal(t, tt.isNetwork, errors.As(tt.err, &networkErr)) + + var invalidRespErr *InvalidResponseError + assert.Equal(t, tt.isInvalidResp, errors.As(tt.err, &invalidRespErr)) + + // Test Unwrap chain reaches the underlying error. + if tt.unwrapsToInner { + require.True(t, errors.Is(tt.err, underlyingErr), + "expected error to unwrap to underlying error") + } + }) + } +} + +func TestWebhookErrorBaseType(t *testing.T) { + t.Parallel() + + inner := fmt.Errorf("some error") + err := &WebhookError{WebhookName: "base-test", Err: inner} + + assert.Equal(t, `webhook "base-test": some error`, err.Error()) + assert.Equal(t, inner, err.Unwrap()) +} diff --git a/pkg/webhook/signing.go b/pkg/webhook/signing.go new file mode 100644 index 0000000000..fe4aff86be --- /dev/null +++ b/pkg/webhook/signing.go @@ -0,0 +1,56 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" +) + +// Header names for webhook HMAC signing. +const ( + // SignatureHeader is the HTTP header containing the HMAC signature. + SignatureHeader = "X-ToolHive-Signature" + // TimestampHeader is the HTTP header containing the Unix timestamp. + TimestampHeader = "X-ToolHive-Timestamp" +) + +// signaturePrefix is the prefix for the HMAC-SHA256 signature value. +const signaturePrefix = "sha256=" + +// SignPayload computes an HMAC-SHA256 signature over the given timestamp and +// payload. The signature is computed over the string "timestamp.payload" and +// returned in the format "sha256=". +func SignPayload(secret []byte, timestamp int64, payload []byte) string { + mac := hmac.New(sha256.New, secret) + // Write the message: "timestamp.payload" + msg := fmt.Sprintf("%d.", timestamp) + mac.Write([]byte(msg)) + mac.Write(payload) + return signaturePrefix + hex.EncodeToString(mac.Sum(nil)) +} + +// VerifySignature verifies an HMAC-SHA256 signature against the given timestamp +// and payload. The signature should be in the format "sha256=". +// Comparison is done in constant time to prevent timing attacks. +func VerifySignature(secret []byte, timestamp int64, payload []byte, signature string) bool { + if !strings.HasPrefix(signature, signaturePrefix) { + return false + } + + sigBytes, err := hex.DecodeString(strings.TrimPrefix(signature, signaturePrefix)) + if err != nil { + return false + } + + mac := hmac.New(sha256.New, secret) + msg := fmt.Sprintf("%d.", timestamp) + mac.Write([]byte(msg)) + mac.Write(payload) + + return hmac.Equal(mac.Sum(nil), sigBytes) +} diff --git a/pkg/webhook/signing_test.go b/pkg/webhook/signing_test.go new file mode 100644 index 0000000000..576b1b7638 --- /dev/null +++ b/pkg/webhook/signing_test.go @@ -0,0 +1,159 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSignPayload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + secret []byte + timestamp int64 + payload []byte + }{ + { + name: "basic payload", + secret: []byte("my-secret"), + timestamp: 1698057000, + payload: []byte(`{"version":"v0.1.0","uid":"test-uid"}`), + }, + { + name: "empty payload", + secret: []byte("my-secret"), + timestamp: 1698057000, + payload: []byte{}, + }, + { + name: "large payload", + secret: []byte("another-secret"), + timestamp: 9999999999, + payload: []byte(`{"key":"` + string(make([]byte, 1024)) + `"}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + sig := SignPayload(tt.secret, tt.timestamp, tt.payload) + assert.NotEmpty(t, sig) + assert.Contains(t, sig, "sha256=") + + // Round-trip: signature must verify. + assert.True(t, VerifySignature(tt.secret, tt.timestamp, tt.payload, sig), + "signature round-trip verification failed") + }) + } +} + +func TestVerifySignature(t *testing.T) { + t.Parallel() + + secret := []byte("test-secret") + timestamp := int64(1698057000) + payload := []byte(`{"version":"v0.1.0","uid":"test"}`) + validSig := SignPayload(secret, timestamp, payload) + + tests := []struct { + name string + secret []byte + timestamp int64 + payload []byte + signature string + expected bool + }{ + { + name: "valid signature", + secret: secret, + timestamp: timestamp, + payload: payload, + signature: validSig, + expected: true, + }, + { + name: "wrong secret", + secret: []byte("wrong-secret"), + timestamp: timestamp, + payload: payload, + signature: validSig, + expected: false, + }, + { + name: "wrong timestamp", + secret: secret, + timestamp: timestamp + 1, + payload: payload, + signature: validSig, + expected: false, + }, + { + name: "tampered payload", + secret: secret, + timestamp: timestamp, + payload: []byte(`{"version":"v0.1.0","uid":"TAMPERED"}`), + signature: validSig, + expected: false, + }, + { + name: "missing sha256 prefix", + secret: secret, + timestamp: timestamp, + payload: payload, + signature: "abcdef1234567890", + expected: false, + }, + { + name: "invalid hex after prefix", + secret: secret, + timestamp: timestamp, + payload: payload, + signature: "sha256=not-valid-hex!", + expected: false, + }, + { + name: "empty signature", + secret: secret, + timestamp: timestamp, + payload: payload, + signature: "", + expected: false, + }, + { + name: "sha256= prefix only", + secret: secret, + timestamp: timestamp, + payload: payload, + signature: "sha256=", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + result := VerifySignature(tt.secret, tt.timestamp, tt.payload, tt.signature) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestSignPayloadDeterministic(t *testing.T) { + t.Parallel() + + secret := []byte("deterministic-test") + timestamp := int64(1234567890) + payload := []byte("test-payload") + + sig1 := SignPayload(secret, timestamp, payload) + sig2 := SignPayload(secret, timestamp, payload) + + assert.Equal(t, sig1, sig2, "same inputs must produce the same signature") +} diff --git a/pkg/webhook/types.go b/pkg/webhook/types.go new file mode 100644 index 0000000000..5c5b80f711 --- /dev/null +++ b/pkg/webhook/types.go @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package webhook implements the core types, HTTP client, HMAC signing, +// and error handling for ToolHive's dynamic webhook middleware system. +package webhook + +import ( + "encoding/json" + "fmt" + "net/url" + "time" +) + +// APIVersion is the version of the webhook API protocol. +const APIVersion = "v0.1.0" + +// DefaultTimeout is the default timeout for webhook HTTP calls. +const DefaultTimeout = 10 * time.Second + +// MaxTimeout is the maximum allowed timeout for webhook HTTP calls. +const MaxTimeout = 30 * time.Second + +// MaxResponseSize is the maximum allowed size in bytes for webhook responses (1 MB). +const MaxResponseSize = 1 << 20 + +// Type indicates whether a webhook is validating or mutating. +type Type string + +const ( + // TypeValidating indicates a validating webhook that accepts or denies requests. + TypeValidating Type = "validating" + // TypeMutating indicates a mutating webhook that transforms requests. + TypeMutating Type = "mutating" +) + +// FailurePolicy defines how webhook errors are handled. +type FailurePolicy string + +const ( + // FailurePolicyFail denies the request on webhook error (fail-closed). + FailurePolicyFail FailurePolicy = "fail" + // FailurePolicyIgnore allows the request on webhook error (fail-open). + FailurePolicyIgnore FailurePolicy = "ignore" +) + +// TLSConfig holds TLS-related configuration for webhook HTTP communication. +type TLSConfig struct { + // CABundlePath is the path to a CA certificate bundle for server verification. + CABundlePath string `json:"ca_bundle_path,omitempty"` + // ClientCertPath is the path to a client certificate for mTLS. + ClientCertPath string `json:"client_cert_path,omitempty"` + // ClientKeyPath is the path to a client key for mTLS. + ClientKeyPath string `json:"client_key_path,omitempty"` + // InsecureSkipVerify disables server certificate verification. + // WARNING: This should only be used for development/testing. + InsecureSkipVerify bool `json:"insecure_skip_verify,omitempty"` +} + +// Config holds the configuration for a single webhook. +type Config struct { + // Name is a unique identifier for this webhook. + Name string `json:"name"` + // URL is the HTTPS endpoint to call. + URL string `json:"url"` + // Timeout is the maximum time to wait for a webhook response. + Timeout time.Duration `json:"timeout"` + // FailurePolicy determines behavior when the webhook call fails. + FailurePolicy FailurePolicy `json:"failure_policy"` + // TLSConfig holds optional TLS configuration (CA bundles, client certs). + TLSConfig *TLSConfig `json:"tls_config,omitempty"` + // HMACSecretRef is an optional reference to an HMAC secret for payload signing. + HMACSecretRef string `json:"hmac_secret_ref,omitempty"` +} + +// Validate checks that the WebhookConfig has valid required fields. +func (c *Config) Validate() error { + if c.Name == "" { + return fmt.Errorf("webhook name is required") + } + if c.URL == "" { + return fmt.Errorf("webhook URL is required") + } + if _, err := url.ParseRequestURI(c.URL); err != nil { + return fmt.Errorf("webhook URL is invalid: %w", err) + } + if c.FailurePolicy != FailurePolicyFail && c.FailurePolicy != FailurePolicyIgnore { + return fmt.Errorf("webhook failure_policy must be %q or %q, got %q", + FailurePolicyFail, FailurePolicyIgnore, c.FailurePolicy) + } + if c.Timeout < 0 { + return fmt.Errorf("webhook timeout must be non-negative") + } + if c.Timeout > MaxTimeout { + return fmt.Errorf("webhook timeout %v exceeds maximum %v", c.Timeout, MaxTimeout) + } + if c.TLSConfig != nil { + if err := validateTLSConfig(c.TLSConfig); err != nil { + return fmt.Errorf("webhook TLS config: %w", err) + } + } + return nil +} + +// Request is the payload sent to webhook endpoints. +type Request struct { + // Version is the webhook API protocol version. + Version string `json:"version"` + // UID is a unique identifier for this request, used for idempotency. + UID string `json:"uid"` + // Timestamp is when the request was created. + Timestamp time.Time `json:"timestamp"` + // Principal contains the authenticated user's identity information. + Principal *Principal `json:"principal"` + // MCPRequest is the raw MCP JSON-RPC request. + MCPRequest json.RawMessage `json:"mcp_request"` + // Context provides additional metadata about the request origin. + Context *RequestContext `json:"context"` +} + +// Principal contains the authenticated user's identity information. +type Principal struct { + // Sub is the subject identifier (user ID). + Sub string `json:"sub"` + // Email is the user's email address. + Email string `json:"email,omitempty"` + // Name is the user's display name. + Name string `json:"name,omitempty"` + // Groups is a list of groups the user belongs to. + Groups []string `json:"groups,omitempty"` + // Claims contains additional identity claims. + Claims map[string]string `json:"claims,omitempty"` +} + +// RequestContext provides metadata about the request origin and environment. +type RequestContext struct { + // ServerName is the ToolHive/vMCP instance name handling the request. + ServerName string `json:"server_name"` + // BackendServer is the actual MCP server being proxied (when using vMCP). + BackendServer string `json:"backend_server,omitempty"` + // Namespace is the Kubernetes namespace, if applicable. + Namespace string `json:"namespace,omitempty"` + // SourceIP is the client's IP address. + SourceIP string `json:"source_ip"` + // Transport is the connection transport type (e.g., "sse", "stdio"). + Transport string `json:"transport"` +} + +// Response is the response from a validating webhook. +type Response struct { + // Version is the webhook API protocol version. + Version string `json:"version"` + // UID is the unique request identifier, echoed back for correlation. + UID string `json:"uid"` + // Allowed indicates whether the request is permitted. + Allowed bool `json:"allowed"` + // Code is an optional HTTP status code for denied requests. + Code int `json:"code,omitempty"` + // Message is an optional human-readable explanation. + Message string `json:"message,omitempty"` + // Reason is an optional machine-readable denial reason. + Reason string `json:"reason,omitempty"` + // Details contains optional structured information about the denial. + Details map[string]string `json:"details,omitempty"` +} + +// MutatingResponse is the response from a mutating webhook. +type MutatingResponse struct { + Response + // PatchType indicates the type of patch (e.g., "json_patch"). + PatchType string `json:"patch_type,omitempty"` + // Patch contains the JSON Patch operations to apply. + Patch json.RawMessage `json:"patch,omitempty"` +} + +// validateTLSConfig validates the TLS configuration for consistency. +func validateTLSConfig(cfg *TLSConfig) error { + // If one of client cert/key is provided, both must be present. + if (cfg.ClientCertPath == "") != (cfg.ClientKeyPath == "") { + return fmt.Errorf("both client_cert_path and client_key_path must be provided for mTLS") + } + return nil +} diff --git a/pkg/webhook/types_test.go b/pkg/webhook/types_test.go new file mode 100644 index 0000000000..4cd2fe236f --- /dev/null +++ b/pkg/webhook/types_test.go @@ -0,0 +1,164 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package webhook + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestConfigValidate(t *testing.T) { + t.Parallel() + + validConfig := func() Config { + return Config{ + Name: "test-webhook", + URL: "https://example.com/webhook", + Timeout: 5 * time.Second, + FailurePolicy: FailurePolicyFail, + } + } + + tests := []struct { + name string + modify func(*Config) + expectError bool + errorContains string + }{ + { + name: "valid config with fail policy", + modify: func(_ *Config) {}, + expectError: false, + }, + { + name: "valid config with ignore policy", + modify: func(c *Config) { + c.FailurePolicy = FailurePolicyIgnore + }, + expectError: false, + }, + { + name: "valid config with zero timeout (uses default)", + modify: func(c *Config) { + c.Timeout = 0 + }, + expectError: false, + }, + { + name: "valid config with TLS", + modify: func(c *Config) { + c.TLSConfig = &TLSConfig{ + CABundlePath: "/path/to/ca.crt", + ClientCertPath: "/path/to/cert.pem", + ClientKeyPath: "/path/to/key.pem", + } + }, + expectError: false, + }, + { + name: "missing name", + modify: func(c *Config) { + c.Name = "" + }, + expectError: true, + errorContains: "name is required", + }, + { + name: "missing URL", + modify: func(c *Config) { + c.URL = "" + }, + expectError: true, + errorContains: "URL is required", + }, + { + name: "invalid URL", + modify: func(c *Config) { + c.URL = "not a url" + }, + expectError: true, + errorContains: "URL is invalid", + }, + { + name: "invalid failure policy", + modify: func(c *Config) { + c.FailurePolicy = "invalid" + }, + expectError: true, + errorContains: "failure_policy", + }, + { + name: "negative timeout", + modify: func(c *Config) { + c.Timeout = -1 * time.Second + }, + expectError: true, + errorContains: "non-negative", + }, + { + name: "timeout exceeds max", + modify: func(c *Config) { + c.Timeout = MaxTimeout + time.Second + }, + expectError: true, + errorContains: "exceeds maximum", + }, + { + name: "mTLS with only cert", + modify: func(c *Config) { + c.TLSConfig = &TLSConfig{ + ClientCertPath: "/path/to/cert.pem", + } + }, + expectError: true, + errorContains: "both client_cert_path and client_key_path", + }, + { + name: "mTLS with only key", + modify: func(c *Config) { + c.TLSConfig = &TLSConfig{ + ClientKeyPath: "/path/to/key.pem", + } + }, + expectError: true, + errorContains: "both client_cert_path and client_key_path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + cfg := validConfig() + tt.modify(&cfg) + + err := cfg.Validate() + + if tt.expectError { + assert.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestTypeConstants(t *testing.T) { + t.Parallel() + + assert.Equal(t, Type("validating"), TypeValidating) + assert.Equal(t, Type("mutating"), TypeMutating) +} + +func TestFailurePolicyConstants(t *testing.T) { + t.Parallel() + + assert.Equal(t, FailurePolicy("fail"), FailurePolicyFail) + assert.Equal(t, FailurePolicy("ignore"), FailurePolicyIgnore) +} From ec85a1058b2720cb1fda42f7c56718828e4f8358 Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Tue, 17 Feb 2026 22:44:07 +0530 Subject: [PATCH 2/3] fix: CI error and codecov Signed-off-by: Sanskarzz --- pkg/webhook/client_test.go | 220 +++++++++++++++++++++++++++++++++++++ pkg/webhook/errors.go | 2 +- 2 files changed, 221 insertions(+), 1 deletion(-) diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 50cad2bbe7..2a28b97a5d 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -10,6 +10,8 @@ import ( "io" "net/http" "net/http/httptest" + "os" + "path/filepath" "strings" "testing" "time" @@ -475,6 +477,224 @@ func TestClientRequestContentType(t *testing.T) { assert.Equal(t, "application/json", capturedContentType) } +func TestBuildTransport(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + caFile := filepath.Join(tmpDir, "ca.crt") + err := os.WriteFile(caFile, []byte("invalid-ca"), 0600) + require.NoError(t, err) + + certFile := filepath.Join(tmpDir, "client.crt") + keyFile := filepath.Join(tmpDir, "client.key") + err = os.WriteFile(certFile, []byte("invalid-cert"), 0600) + require.NoError(t, err) + err = os.WriteFile(keyFile, []byte("invalid-key"), 0600) + require.NoError(t, err) + + tests := []struct { + name string + tlsCfg *TLSConfig + expectError bool + }{ + { + name: "nil config", + tlsCfg: nil, + expectError: false, + }, + { + name: "insecure skip verify", + tlsCfg: &TLSConfig{ + InsecureSkipVerify: true, + }, + expectError: false, + }, + { + name: "non-existent ca bundle", + tlsCfg: &TLSConfig{ + CABundlePath: "/non/existent/path", + }, + expectError: true, + }, + { + name: "invalid ca bundle content", + tlsCfg: &TLSConfig{ + CABundlePath: caFile, + }, + expectError: true, + }, + { + name: "non-existent client cert", + tlsCfg: &TLSConfig{ + ClientCertPath: "/non/existent/cert", + ClientKeyPath: keyFile, + }, + expectError: true, + }, + { + name: "invalid client cert/key", + tlsCfg: &TLSConfig{ + ClientCertPath: certFile, + ClientKeyPath: keyFile, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + transport, err := buildTransport(tt.tlsCfg) + if tt.expectError { + assert.Error(t, err) + assert.Nil(t, transport) + } else { + assert.NoError(t, err) + assert.NotNil(t, transport) + if tt.tlsCfg != nil && tt.tlsCfg.InsecureSkipVerify { + assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + } + } + }) + } +} + +func TestClassifyError(t *testing.T) { + t.Parallel() + + t.Run("non-timeout network error", func(t *testing.T) { + err := errors.New("connection refused") + classified := classifyError("test", err) + var netErr *NetworkError + assert.True(t, errors.As(classified, &netErr)) + }) +} + +func TestTruncateBody(t *testing.T) { + t.Parallel() + + t.Run("short body", func(t *testing.T) { + body := []byte("short") + assert.Equal(t, "short", truncateBody(body)) + }) + + t.Run("long body", func(t *testing.T) { + body := []byte(strings.Repeat("a", 300)) + truncated := truncateBody(body) + assert.Equal(t, 256+3, len(truncated)) + assert.True(t, strings.HasSuffix(truncated, "...")) + }) +} + +func TestClientCallErrors(t *testing.T) { + t.Parallel() + + cfg := Config{ + Name: "error-test", + URL: "invalid URL \x00", // Will cause http.NewRequest to fail + Timeout: 1 * time.Second, + } + client := newTestClient(cfg, TypeValidating, nil) + + t.Run("request creation failure", func(t *testing.T) { + _, err := client.Call(context.Background(), &Request{}) + assert.Error(t, err) + var networkErr *NetworkError + assert.True(t, errors.As(err, &networkErr)) + }) + + t.Run("unmarshal failure Call", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not-json")) + })) + defer server.Close() + + client.config.URL = server.URL + _, err := client.Call(context.Background(), &Request{}) + assert.Error(t, err) + var invalidErr *InvalidResponseError + assert.True(t, errors.As(err, &invalidErr)) + }) + + t.Run("unmarshal failure CallMutating", func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("not-json")) + })) + defer server.Close() + + client.config.URL = server.URL + _, err := client.CallMutating(context.Background(), &Request{}) + assert.Error(t, err) + var invalidErr *InvalidResponseError + assert.True(t, errors.As(err, &invalidErr)) + }) + + t.Run("doHTTPCall failure CallMutating", func(t *testing.T) { + client.config.URL = "http://invalid-address.local" + _, err := client.CallMutating(context.Background(), &Request{}) + assert.Error(t, err) + }) +} + +type errorReader struct{} + +func (e *errorReader) Read(_ []byte) (n int, err error) { + return 0, errors.New("forced read error") +} +func (e *errorReader) Close() error { return nil } + +func TestDoHTTPCallReadError(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + // This handler won't be reached because we're testing doHTTPCall error path + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // We need a server that returns a body that fails on Read + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Length", "10") + w.WriteHeader(http.StatusOK) + // We can't easily force Read to fail from net/http handler, + // but we can mock the http client or its transport. + })) + defer ts.Close() + + cfg := Config{ + Name: "read-err", + URL: ts.URL, + FailurePolicy: FailurePolicyFail, + } + client, err := NewClient(cfg, TypeValidating, nil) + require.NoError(t, err) + + // Mock the RoundTripper to return a body that fails on Read + rt := &mockRoundTripper{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Body: &errorReader{}, + }, + } + client.httpClient.Transport = rt + + _, err = client.Call(context.Background(), &Request{}) + assert.Error(t, err) + var networkErr *NetworkError + assert.True(t, errors.As(err, &networkErr)) + assert.Contains(t, err.Error(), "forced read error") +} + +type mockRoundTripper struct { + resp *http.Response + err error +} + +func (m *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) { + return m.resp, m.err +} + // newTestClient creates a webhook Client suitable for testing with httptest servers. // It bypasses URL validation (httptest uses HTTP, not HTTPS). func newTestClient(cfg Config, webhookType Type, hmacSecret []byte) *Client { diff --git a/pkg/webhook/errors.go b/pkg/webhook/errors.go index 6424f49a94..d4c06bd090 100644 --- a/pkg/webhook/errors.go +++ b/pkg/webhook/errors.go @@ -45,7 +45,7 @@ func (e *NetworkError) Error() string { return fmt.Sprintf("webhook %q: network error: %v", e.WebhookName, e.Err) } -// InvalidResponseError indicates that a webhook returned an unparseable or invalid response. +// InvalidResponseError indicates that a webhook returned an unparsable or invalid response. type InvalidResponseError struct { WebhookError } From 5ab0f59dae22db16ecd5023cf3cf8b7942249acc Mon Sep 17 00:00:00 2001 From: Sanskarzz Date: Wed, 18 Feb 2026 00:20:23 +0530 Subject: [PATCH 3/3] fix: CI lint error fix Signed-off-by: Sanskarzz --- pkg/webhook/client.go | 13 +++++++--- pkg/webhook/client_test.go | 51 +++++++++++++++++++++++++++++--------- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/pkg/webhook/client.go b/pkg/webhook/client.go index 13cd6d3c50..f10f63ed5c 100644 --- a/pkg/webhook/client.go +++ b/pkg/webhook/client.go @@ -17,6 +17,8 @@ import ( "os" "strconv" "time" + + "github.com/stacklok/toolhive/pkg/networking" ) // Client is an HTTP client for calling webhook endpoints. @@ -113,6 +115,7 @@ func (c *Client) doHTTPCall(ctx context.Context, body []byte) ([]byte, error) { httpReq.Header.Set(TimestampHeader, strconv.FormatInt(timestamp, 10)) } + // #nosec G704 -- URL is validated in Config.Validate and we use ValidatingTransport for SSRF protection. resp, err := c.httpClient.Do(httpReq) if err != nil { return nil, classifyError(c.config.Name, err) @@ -147,8 +150,9 @@ func (c *Client) doHTTPCall(ctx context.Context, body []byte) ([]byte, error) { return respBody, nil } -// buildTransport creates an http.Transport with the specified TLS configuration. -func buildTransport(tlsCfg *TLSConfig) (*http.Transport, error) { +// buildTransport creates an http.RoundTripper with the specified TLS configuration, +// wrapped in a ValidatingTransport for security. +func buildTransport(tlsCfg *TLSConfig) (http.RoundTripper, error) { transport := &http.Transport{ TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: 10 * time.Second, @@ -193,7 +197,10 @@ func buildTransport(tlsCfg *TLSConfig) (*http.Transport, error) { } transport.TLSClientConfig = tlsConfig - return transport, nil + return &networking.ValidatingTransport{ + Transport: transport, + InsecureAllowHTTP: false, + }, nil } // classifyError examines an HTTP client error and returns an appropriately diff --git a/pkg/webhook/client_test.go b/pkg/webhook/client_test.go index 2a28b97a5d..22d7c90c2c 100644 --- a/pkg/webhook/client_test.go +++ b/pkg/webhook/client_test.go @@ -18,6 +18,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/stacklok/toolhive/pkg/networking" ) func TestNewClient(t *testing.T) { @@ -543,6 +545,7 @@ func TestBuildTransport(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + t.Parallel() transport, err := buildTransport(tt.tlsCfg) if tt.expectError { assert.Error(t, err) @@ -551,7 +554,11 @@ func TestBuildTransport(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, transport) if tt.tlsCfg != nil && tt.tlsCfg.InsecureSkipVerify { - assert.True(t, transport.TLSClientConfig.InsecureSkipVerify) + vt, ok := transport.(*networking.ValidatingTransport) + require.True(t, ok, "expected *networking.ValidatingTransport") + tr, ok := vt.Transport.(*http.Transport) + require.True(t, ok, "expected *http.Transport") + assert.True(t, tr.TLSClientConfig.InsecureSkipVerify) } } }) @@ -562,6 +569,7 @@ func TestClassifyError(t *testing.T) { t.Parallel() t.Run("non-timeout network error", func(t *testing.T) { + t.Parallel() err := errors.New("connection refused") classified := classifyError("test", err) var netErr *NetworkError @@ -573,11 +581,13 @@ func TestTruncateBody(t *testing.T) { t.Parallel() t.Run("short body", func(t *testing.T) { + t.Parallel() body := []byte("short") assert.Equal(t, "short", truncateBody(body)) }) t.Run("long body", func(t *testing.T) { + t.Parallel() body := []byte(strings.Repeat("a", 300)) truncated := truncateBody(body) assert.Equal(t, 256+3, len(truncated)) @@ -588,14 +598,14 @@ func TestTruncateBody(t *testing.T) { func TestClientCallErrors(t *testing.T) { t.Parallel() - cfg := Config{ + client := newTestClient(Config{ Name: "error-test", URL: "invalid URL \x00", // Will cause http.NewRequest to fail Timeout: 1 * time.Second, - } - client := newTestClient(cfg, TypeValidating, nil) + }, TypeValidating, nil) t.Run("request creation failure", func(t *testing.T) { + t.Parallel() _, err := client.Call(context.Background(), &Request{}) assert.Error(t, err) var networkErr *NetworkError @@ -603,46 +613,63 @@ func TestClientCallErrors(t *testing.T) { }) t.Run("unmarshal failure Call", func(t *testing.T) { + t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("not-json")) })) defer server.Close() - client.config.URL = server.URL - _, err := client.Call(context.Background(), &Request{}) + testClient := newTestClient(Config{ + Name: "unmarshal-fail", + URL: server.URL, + FailurePolicy: FailurePolicyFail, + }, TypeValidating, nil) + + _, err := testClient.Call(context.Background(), &Request{}) assert.Error(t, err) var invalidErr *InvalidResponseError assert.True(t, errors.As(err, &invalidErr)) }) t.Run("unmarshal failure CallMutating", func(t *testing.T) { + t.Parallel() server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) w.Write([]byte("not-json")) })) defer server.Close() - client.config.URL = server.URL - _, err := client.CallMutating(context.Background(), &Request{}) + testClient := newTestClient(Config{ + Name: "unmarshal-fail-mutating", + URL: server.URL, + FailurePolicy: FailurePolicyFail, + }, TypeMutating, nil) + + _, err := testClient.CallMutating(context.Background(), &Request{}) assert.Error(t, err) var invalidErr *InvalidResponseError assert.True(t, errors.As(err, &invalidErr)) }) t.Run("doHTTPCall failure CallMutating", func(t *testing.T) { - client.config.URL = "http://invalid-address.local" - _, err := client.CallMutating(context.Background(), &Request{}) + t.Parallel() + testClient := newTestClient(Config{ + Name: "http-fail", + URL: "http://invalid-address.local", + FailurePolicy: FailurePolicyFail, + }, TypeMutating, nil) + _, err := testClient.CallMutating(context.Background(), &Request{}) assert.Error(t, err) }) } type errorReader struct{} -func (e *errorReader) Read(_ []byte) (n int, err error) { +func (*errorReader) Read(_ []byte) (n int, err error) { return 0, errors.New("forced read error") } -func (e *errorReader) Close() error { return nil } +func (*errorReader) Close() error { return nil } func TestDoHTTPCallReadError(t *testing.T) { t.Parallel()