diff --git a/internal/attestation/attestation.go b/internal/attestation/attestation.go index e9b04e98..01a74dd4 100644 --- a/internal/attestation/attestation.go +++ b/internal/attestation/attestation.go @@ -322,8 +322,14 @@ func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwtToken)) - // Make the HTTP request - client := &http.Client{Timeout: 30 * time.Second} + // Make the HTTP request. Disable redirects so a compromised backend + // cannot bounce us to an internal service (SSRF). + client := &http.Client{ + Timeout: 30 * time.Second, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, + } resp, err := client.Do(req) if err != nil { log.Logger.Debugw("failed to make POST request in nonce endpoint request", "error", err) @@ -343,6 +349,10 @@ func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time "status", resp.Status, "content_type", resp.Header.Get("Content-Type")) + if resp.StatusCode != http.StatusOK { + return "", time.Time{}, fmt.Errorf("nonce endpoint returned HTTP %d", resp.StatusCode) + } + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { log.Logger.Debugw("failed to decode response in nonce endpoint request", "error", err) return "", time.Time{}, err @@ -350,10 +360,11 @@ func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time if response.Error != "" { log.Logger.Debugw("error from server in nonce endpoint request", "error", response.Error) - } else { - log.Logger.Debugw("Nonce received from server", "nonce_refresh_timestamp", response.NonceRefreshTimestamp) + return "", time.Time{}, fmt.Errorf("nonce endpoint returned error: %s", response.Error) } + log.Logger.Debugw("Nonce received from server", "nonce_refresh_timestamp", response.NonceRefreshTimestamp) + return response.Nonce, response.NonceRefreshTimestamp, nil } diff --git a/internal/attestation/attestation_test.go b/internal/attestation/attestation_test.go index 457d5383..27165747 100644 --- a/internal/attestation/attestation_test.go +++ b/internal/attestation/attestation_test.go @@ -24,6 +24,7 @@ import ( "net/http/httptest" "path/filepath" "strings" + "sync/atomic" "testing" "time" @@ -193,6 +194,16 @@ type testNonceResponse struct { Error string `json:"error,omitempty"` } +func useDefaultTransport(t *testing.T, transport http.RoundTripper) { + t.Helper() + + orig := http.DefaultTransport + http.DefaultTransport = transport + t.Cleanup(func() { + http.DefaultTransport = orig + }) +} + func TestManager_GetNonce_MockHTTP(t *testing.T) { // Test the nonce parsing logic without actually calling the private method // This tests the HTTP interaction pattern @@ -292,6 +303,59 @@ func TestManager_GetNonce_ServerError(t *testing.T) { assert.True(t, response.NonceRefreshTimestamp.IsZero()) } +func TestManager_GetNonce_RejectsNonOKStatus(t *testing.T) { + manager := newTestManager(t) + var redirectTargetCalled atomic.Bool + var server *httptest.Server + server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/redirected" { + redirectTargetCalled.Store(true) + w.WriteHeader(http.StatusOK) + return + } + http.Redirect(w, r, server.URL+"/redirected", http.StatusFound) + })) + defer server.Close() + + useDefaultTransport(t, server.Client().Transport) + stateFile := setupAttestationMetadataDB(t, map[string]string{ + "nonce_endpoint": server.URL, + }) + useTestStateFile(t, stateFile) + + nonce, refresh, err := manager.getNonce("test-jwt-token", "test-machine-id") + + require.Error(t, err) + assert.Contains(t, err.Error(), "nonce endpoint returned HTTP 302") + assert.Empty(t, nonce) + assert.True(t, refresh.IsZero()) + assert.False(t, redirectTargetCalled.Load()) +} + +func TestManager_GetNonce_RejectsServerErrorPayload(t *testing.T) { + manager := newTestManager(t) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + require.NoError(t, json.NewEncoder(w).Encode(map[string]string{ + "error": "Invalid token", + })) + })) + defer server.Close() + + useDefaultTransport(t, server.Client().Transport) + stateFile := setupAttestationMetadataDB(t, map[string]string{ + "nonce_endpoint": server.URL, + }) + useTestStateFile(t, stateFile) + + nonce, refresh, err := manager.getNonce("test-jwt-token", "test-machine-id") + + require.Error(t, err) + assert.Contains(t, err.Error(), "nonce endpoint returned error: Invalid token") + assert.Empty(t, nonce) + assert.True(t, refresh.IsZero()) +} + func TestManager_GetValidatedNonceEndpoint_UsesStoredNonceEndpoint(t *testing.T) { manager := newTestManager(t) stateFile := setupAttestationMetadataDB(t, map[string]string{ diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index f68302bc..8d10e713 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -81,10 +81,14 @@ func ValidateLocalServerURL(raw string) (*url.URL, error) { // described by serverURL. For unix socket URLs it installs a custom dialer; for // TCP URLs it returns a plain client. func NewAgentHTTPClient(serverURL *url.URL) *http.Client { + noRedirect := func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + } if serverURL.Scheme == "unix" { socketPath := serverURL.Path return &http.Client{ - Timeout: 5 * time.Second, + Timeout: 5 * time.Second, + CheckRedirect: noRedirect, Transport: &http.Transport{ DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return net.Dial("unix", socketPath) @@ -92,7 +96,7 @@ func NewAgentHTTPClient(serverURL *url.URL) *http.Client { }, } } - return &http.Client{Timeout: 5 * time.Second} + return &http.Client{Timeout: 5 * time.Second, CheckRedirect: noRedirect} } // AgentBaseURL returns the HTTP base URL to use when constructing request URLs. diff --git a/internal/endpoint/endpoint_test.go b/internal/endpoint/endpoint_test.go index 289b8ccc..fe9fde54 100644 --- a/internal/endpoint/endpoint_test.go +++ b/internal/endpoint/endpoint_test.go @@ -16,6 +16,9 @@ package endpoint import ( + "net/http" + "net/http/httptest" + "sync/atomic" "testing" "time" @@ -128,12 +131,39 @@ func TestNewAgentHTTPClient(t *testing.T) { assert.Equal(t, 5*time.Second, client.Timeout) }) + t.Run("tcp_client_does_not_follow_redirects", func(t *testing.T) { + var redirectTargetCalled atomic.Bool + redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirectTargetCalled.Store(true) + w.WriteHeader(http.StatusOK) + })) + defer redirectTarget.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectTarget.URL, http.StatusFound) + })) + defer server.Close() + + u, err := ValidateLocalServerURL(server.URL) + require.NoError(t, err) + client := NewAgentHTTPClient(u) + + resp, err := client.Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusFound, resp.StatusCode) + assert.Equal(t, redirectTarget.URL, resp.Header.Get("Location")) + assert.False(t, redirectTargetCalled.Load()) + }) + t.Run("unix_client_has_timeout_and_transport", func(t *testing.T) { u, err := ValidateLocalServerURL("/run/fleetint/fleetint.sock") require.NoError(t, err) client := NewAgentHTTPClient(u) assert.NotNil(t, client) assert.Equal(t, 5*time.Second, client.Timeout) + assert.NotNil(t, client.CheckRedirect, "unix client should disable redirects") assert.NotNil(t, client.Transport, "unix client should have a custom transport") }) } diff --git a/internal/enrollment/enrollment.go b/internal/enrollment/enrollment.go index d1436ce9..67c26175 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -47,9 +47,13 @@ func PerformEnrollment(ctx context.Context, enrollEndpoint, sakToken string) (st // Use the provided enrollment endpoint directly enrollURL := enrollEndpoint - // Create HTTP client with timeout + // Create HTTP client with timeout. Disable redirects so a compromised + // backend cannot bounce us to an internal service (SSRF). client := &http.Client{ Timeout: 30 * time.Second, + CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }, } // Create HTTP request with empty body diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index 27ae7784..74a1cc35 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -145,6 +145,28 @@ func TestPerformEnrollment_HTTPStatusCodes(t *testing.T) { } } +func TestPerformEnrollment_DoesNotFollowRedirects(t *testing.T) { + redirectTargetCalled := false + redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirectTargetCalled = true + w.WriteHeader(http.StatusOK) + })) + defer redirectTarget.Close() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, redirectTarget.URL, http.StatusFound) + })) + defer server.Close() + + ctx := context.Background() + token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + + require.Error(t, err) + assert.Contains(t, err.Error(), "status 302") + assert.Empty(t, token) + assert.False(t, redirectTargetCalled) +} + func TestPerformEnrollment_MissingJWTToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Send response without JWT token