diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2ac732bd..7ec52836 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [ main, 'release/**' ] env: - GO_VERSION: '1.26.1' + GO_VERSION: '1.26.2' GOFLAGS: '-trimpath' jobs: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index edc1403f..882e1970 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -6,7 +6,7 @@ on: - 'v*' # Triggers on version tags like v1.0.0, v2.1.3, etc. env: - GO_VERSION: '1.26.1' + GO_VERSION: '1.26.2' permissions: contents: write # Needed for creating GitHub releases diff --git a/.golangci.yml b/.golangci.yml index 4ddb2e22..67e46a3a 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,6 +1,6 @@ version: "2" run: - go: 1.26.1 + go: 1.26.2 linters: default: none enable: diff --git a/SECURITY.md b/SECURITY.md index 693f32a7..a32655dd 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -49,7 +49,7 @@ This repository contains the `fleetint` host agent. 
The notes below are intended - Exporter token refresh and endpoint reload: [`internal/exporter/exporter.go`](internal/exporter/exporter.go) - Local HTTP server routes: [`internal/server/server.go`](internal/server/server.go) - Optional fault injection handler: [`internal/server/handlers_inject_fault.go`](internal/server/handlers_inject_fault.go) -- Remote attestation and `nvattest` invocation: [`internal/attestation/attestation.go`](internal/attestation/attestation.go) +- Remote attestation and `nvattest` invocation: [`internal/attestation/manager.go`](internal/attestation/manager.go), [`internal/attestation/collector.go`](internal/attestation/collector.go) ### Threat Model diff --git a/cmd/fleetint/enroll.go b/cmd/fleetint/enroll.go index 42d2b504..a4740e46 100644 --- a/cmd/fleetint/enroll.go +++ b/cmd/fleetint/enroll.go @@ -22,21 +22,12 @@ import ( "os" "strings" - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/urfave/cli" - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" - "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" "github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" ) -var ( - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { - return enrollment.PerformEnrollment(context.Background(), enrollEndpoint, sakToken) - } - storeEnrollmentConfig = storeConfigInMetadata -) +var performEnrollWorkflow = enrollment.Enroll // resolveToken returns the SAK token from --token, --token-file, or stdin. 
func resolveToken(cliContext *cli.Context) (string, error) { @@ -48,7 +39,7 @@ func resolveToken(cliContext *cli.Context) (string, error) { } if tokenFile != "" { - const maxTokenSize = 1 << 20 // 1 MiB -- SAK tokens are small; anything larger is a mistake + const maxTokenSize = 1 << 20 var raw []byte var err error if tokenFile == "-" { @@ -105,84 +96,5 @@ func enrollCommand(cliContext *cli.Context) error { fmt.Fprintln(writerFromContext(cliContext), "Proceeding with enrollment because --force was provided") } - baseURL, err := endpoint.ValidateBackendEndpoint(baseEndpoint) - if err != nil { - return fmt.Errorf("invalid enrollment endpoint: %w", err) - } - - // Construct enroll endpoint - enrollEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "enroll") - if err != nil { - return fmt.Errorf("failed to construct enroll endpoint: %w", err) - } - - // Make enrollment request to get JWT token - jwtToken, err := performEnrollment(enrollEndpoint, sakToken) - if err != nil { - // Error already printed to stderr by PerformEnrollment - return err - } - - // Construct other endpoints using url.JoinPath for proper URL handling - metricsEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "metrics") - if err != nil { - return fmt.Errorf("failed to construct metrics endpoint: %w", err) - } - logsEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "logs") - if err != nil { - return fmt.Errorf("failed to construct logs endpoint: %w", err) - } - nonceEndpoint, err := endpoint.JoinPath(baseURL, "api", "v1", "health", "nonce") - if err != nil { - return fmt.Errorf("failed to construct nonce endpoint: %w", err) - } - - // Store endpoints and JWT token in metadata table - if err := storeEnrollmentConfig(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken); err != nil { - return fmt.Errorf("failed to store configuration: %w", err) - } - - return nil -} - -func storeConfigInMetadata(enrollEndpoint, 
metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { - stateFile, err := config.DefaultStateFile() - if err != nil { - return fmt.Errorf("failed to get state file path: %w", err) - } - - dbRW, err := sqlite.Open(stateFile) - if err != nil { - return fmt.Errorf("failed to open state database: %w", err) - } - defer dbRW.Close() - - if err := pkgmetadata.CreateTableMetadata(context.Background(), dbRW); err != nil { - return fmt.Errorf("failed to create metadata table: %w", err) - } - - // Store SAK token (for JWT refresh), JWT token (for API calls), and all endpoints - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "sak_token", sakToken); err != nil { - return fmt.Errorf("failed to set SAK token: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { - return fmt.Errorf("failed to set JWT token: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "enroll_endpoint", enrollEndpoint); err != nil { - return fmt.Errorf("failed to set enroll endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "metrics_endpoint", metricsEndpoint); err != nil { - return fmt.Errorf("failed to set metrics endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "logs_endpoint", logsEndpoint); err != nil { - return fmt.Errorf("failed to set logs endpoint: %w", err) - } - if err := pkgmetadata.SetMetadata(context.Background(), dbRW, "nonce_endpoint", nonceEndpoint); err != nil { - return fmt.Errorf("failed to set nonce endpoint: %w", err) - } - if err := config.SecureStateFilePermissions(stateFile); err != nil { - return fmt.Errorf("failed to secure state database permissions: %w", err) - } - - return nil + return performEnrollWorkflow(context.Background(), baseEndpoint, sakToken) } diff --git a/cmd/fleetint/enroll_test.go b/cmd/fleetint/enroll_test.go index faf63b8f..e191530b 100644 --- 
a/cmd/fleetint/enroll_test.go +++ b/cmd/fleetint/enroll_test.go @@ -17,9 +17,8 @@ package main import ( "bytes" + "context" "fmt" - "os" - "path/filepath" "testing" "github.com/stretchr/testify/assert" @@ -50,12 +49,10 @@ func TestEnrollCommandPrecheckError(t *testing.T) { func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck - originalPerformEnrollment := performEnrollment - originalStoreConfig := storeEnrollmentConfig + originalEnrollWorkflow := performEnrollWorkflow t.Cleanup(func() { runPrecheck = originalRunPrecheck - performEnrollment = originalPerformEnrollment - storeEnrollmentConfig = originalStoreConfig + performEnrollWorkflow = originalEnrollWorkflow }) enrollmentCalled := false @@ -66,11 +63,8 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, sakToken string) error { enrollmentCalled = true - return "jwt-token", nil - } - storeEnrollmentConfig = func(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } @@ -88,12 +82,10 @@ func TestEnrollCommandBlocksOnFailedPrecheck(t *testing.T) { func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { originalRunPrecheck := runPrecheck - originalPerformEnrollment := performEnrollment - originalStoreConfig := storeEnrollmentConfig + originalEnrollWorkflow := performEnrollWorkflow t.Cleanup(func() { runPrecheck = originalRunPrecheck - performEnrollment = originalPerformEnrollment - storeEnrollmentConfig = originalStoreConfig + performEnrollWorkflow = originalEnrollWorkflow }) enrollmentCalled := false @@ -104,11 +96,8 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { }, }, nil } - performEnrollment = func(enrollEndpoint, sakToken string) (string, error) { + performEnrollWorkflow = func(ctx context.Context, baseEndpoint, 
sakToken string) error { enrollmentCalled = true - return "jwt-token", nil - } - storeEnrollmentConfig = func(enrollEndpoint, metricsEndpoint, logsEndpoint, nonceEndpoint, jwtToken, sakToken string) error { return nil } @@ -120,35 +109,3 @@ func TestEnrollCommandForceBypassesFailedPrecheck(t *testing.T) { require.NoError(t, err) assert.True(t, enrollmentCalled) } - -func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { - if os.Geteuid() == 0 { - t.Skip("test expects non-root default state path resolution") - } - - tmpHome := t.TempDir() - t.Setenv("HOME", tmpHome) - - err := storeConfigInMetadata( - "https://example.com/api/v1/health/enroll", - "https://example.com/api/v1/health/metrics", - "https://example.com/api/v1/health/logs", - "https://example.com/api/v1/health/nonce", - "jwt-token", - "sak-token", - ) - require.NoError(t, err) - - stateFile := filepath.Join(tmpHome, ".fleetint", "fleetint.state") - for _, candidate := range []string{stateFile, stateFile + "-wal", stateFile + "-shm"} { - info, err := os.Stat(candidate) - if os.IsNotExist(err) { - if candidate == stateFile { - require.NoError(t, err) - } - continue - } - require.NoError(t, err) - assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) - } -} diff --git a/cmd/fleetint/run.go b/cmd/fleetint/run.go index 466fbbb1..aaede648 100644 --- a/cmd/fleetint/run.go +++ b/cmd/fleetint/run.go @@ -202,15 +202,6 @@ func configureHealthExporterFromEnv(cfg *config.Config) error { return err } - // FLEETINT_ATTESTATION_JITTER_ENABLED - Enable/disable attestation jitter - if err := setBoolFromEnv("FLEETINT_ATTESTATION_JITTER_ENABLED", &he.Attestation.JitterEnabled, "set attestation jitter enabled from env", "attestation_jitter_enabled"); err != nil { - return err - } - - if err := setDurationFromEnv("FLEETINT_ATTESTATION_INTERVAL", &he.Attestation.Interval, "set attestation interval from env", "attestation_interval", 0, 0); err != nil { - return err - } - // Lookbacks if err := 
setDurationFromEnv("FLEETINT_METRICS_LOOKBACK", &he.MetricsLookback, "set health exporter metrics lookback from env", "metrics_lookback", 0, 0); err != nil { return err @@ -231,6 +222,38 @@ func configureHealthExporterFromEnv(cfg *config.Config) error { return nil }
+// configureLoopConfigFromEnv applies the FLEETINT_INVENTORY_* and
+// FLEETINT_ATTESTATION_* environment overrides to cfg. A section whose
+// config pointer is nil is skipped, and the first setter error is returned.
+func configureLoopConfigFromEnv(cfg *config.Config) error {
+	if inv := cfg.Inventory; inv != nil {
+		err := setBoolFromEnv("FLEETINT_INVENTORY_ENABLED", &inv.Enabled, "set inventory enabled from env", "inventory_enabled")
+		if err != nil {
+			return err
+		}
+		err = setDurationFromEnv("FLEETINT_INVENTORY_INTERVAL", &inv.Interval, "set inventory interval from env", "inventory_interval", time.Minute, 0)
+		if err != nil {
+			return err
+		}
+	}
+
+	if att := cfg.Attestation; att != nil {
+		err := setBoolFromEnv("FLEETINT_ATTESTATION_ENABLED", &att.Enabled, "set attestation enabled from env", "attestation_enabled")
+		if err != nil {
+			return err
+		}
+		err = setDurationFromEnv("FLEETINT_ATTESTATION_INITIAL_INTERVAL", &att.InitialInterval, "set attestation initial interval from env", "attestation_initial_interval", time.Minute, 0)
+		if err != nil {
+			return err
+		}
+		err = setDurationFromEnv("FLEETINT_ATTESTATION_INTERVAL", &att.Interval, "set attestation interval from env", "attestation_interval", time.Minute, 0)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
 func runCommand(cliContext *cli.Context) error { logLevel := cliContext.String("log-level") logFile := cliContext.String("log-file") @@ -310,6 +333,9 @@ func runCommand(cliContext *cli.Context) error { if err := configureHealthExporterFromEnv(cfg); err != nil { return fmt.Errorf("failed to configure health exporter from environment variables: %w", err) } + if err := configureLoopConfigFromEnv(cfg); err != nil { + return fmt.Errorf("failed to configure loop settings from environment variables: %w", err) + } log.Logger.Infow("health exporter configuration", "cfg", cfg.HealthExporter) if listenAddress != "" { diff --git 
a/cmd/fleetint/status.go b/cmd/fleetint/status.go index 6c61d4fc..edb7666a 100644 --- a/cmd/fleetint/status.go +++ b/cmd/fleetint/status.go @@ -31,11 +31,18 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/systemd" "github.com/urfave/cli" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/cmdutil" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" ) +type enrollmentStatus struct { + baseURL string + metricsEndpoint string + logsEndpoint string +} + func statusCommand(cliContext *cli.Context) error { logLevel := cliContext.String("log-level") serverURL := cliContext.String("server-url") @@ -78,25 +85,21 @@ func statusCommand(cliContext *cli.Context) error { defer dbRO.Close() log.Logger.Debugw("successfully opened state file for reading") - metricsEndpoint, err := pkgmetadata.ReadMetadata(rootCtx, dbRO, "metrics_endpoint") - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("failed to read metrics endpoint: %w", err) - } - logsEndpoint, err := pkgmetadata.ReadMetadata(rootCtx, dbRO, "logs_endpoint") - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("failed to read logs endpoint: %w", err) + enrollment, err := readEnrollmentStatus(rootCtx, dbRO) + if err != nil { + return err } - if metricsEndpoint != "" || logsEndpoint != "" { + if enrollment.baseURL != "" || enrollment.metricsEndpoint != "" || enrollment.logsEndpoint != "" { fmt.Printf("%s enrolled\n", cmdutil.CheckMark) - if metricsEndpoint != "" { - fmt.Printf(" metrics endpoint: %s\n", metricsEndpoint) + if enrollment.metricsEndpoint != "" { + fmt.Printf(" metrics endpoint: %s\n", enrollment.metricsEndpoint) } - if logsEndpoint != "" { - fmt.Printf(" logs endpoint: %s\n", logsEndpoint) + if enrollment.logsEndpoint != "" { + fmt.Printf(" logs endpoint: %s\n", enrollment.logsEndpoint) } } else { - fmt.Printf("%s not enrolled 
(no endpoint configured)\n", cmdutil.WarningSign) + fmt.Printf("%s not enrolled (no backend or legacy endpoints configured)\n", cmdutil.WarningSign) } var active bool @@ -159,3 +162,51 @@ func statusCommand(cliContext *cli.Context) error { fmt.Printf("%s successfully checked fleetint health\n", cmdutil.CheckMark) return nil }
+
+// readEnrollmentStatus loads the enrollment details shown by the status
+// command. A stored backend base URL that passes validation is preferred and
+// the metrics/logs endpoints are derived from it; when it is empty or invalid
+// the legacy metrics_endpoint/logs_endpoint metadata keys are read instead.
+func readEnrollmentStatus(ctx context.Context, dbRO *sql.DB) (*enrollmentStatus, error) {
+	baseURL, err := pkgmetadata.ReadMetadata(ctx, dbRO, agentstate.MetadataKeyBackendBaseURL)
+	if err != nil && !errors.Is(err, sql.ErrNoRows) {
+		return nil, fmt.Errorf("failed to read backend base URL: %w", err)
+	}
+
+	if baseURL != "" {
+		validated, validateErr := endpoint.ValidateBackendEndpoint(baseURL)
+		if validateErr == nil {
+			metrics, err := endpoint.JoinPath(validated, "api", "v1", "health", "metrics")
+			if err != nil {
+				return nil, fmt.Errorf("failed to construct metrics endpoint: %w", err)
+			}
+			logs, err := endpoint.JoinPath(validated, "api", "v1", "health", "logs")
+			if err != nil {
+				return nil, fmt.Errorf("failed to construct logs endpoint: %w", err)
+			}
+			return &enrollmentStatus{baseURL: baseURL, metricsEndpoint: metrics, logsEndpoint: logs}, nil
+		}
+		// An unusable base URL is reported and dropped so the legacy keys decide the result.
+		log.Logger.Warnw("ignoring invalid backend base URL in metadata", "backend_base_url", baseURL, "error", validateErr)
+	}
+
+	metrics, err := readLegacyEndpoint(ctx, dbRO, "metrics_endpoint")
+	if err != nil {
+		return nil, err
+	}
+	logs, err := readLegacyEndpoint(ctx, dbRO, "logs_endpoint")
+	if err != nil {
+		return nil, err
+	}
+	return &enrollmentStatus{metricsEndpoint: metrics, logsEndpoint: logs}, nil
+}
+
+// readLegacyEndpoint reads one legacy endpoint key from the metadata table.
+// sql.ErrNoRows is not treated as a failure; any other error is wrapped with the key name.
+func readLegacyEndpoint(ctx context.Context, dbRO *sql.DB, key string) (string, error) {
+	value, err := pkgmetadata.ReadMetadata(ctx, dbRO, key)
+	if err == nil || errors.Is(err, sql.ErrNoRows) {
+		return value, nil
+	}
+	return "", fmt.Errorf("failed to read %s: %w", key, err)
+}
diff --git a/cmd/fleetint/unenroll.go b/cmd/fleetint/unenroll.go index 6bb8b6d2..a968a033 100644 --- 
a/cmd/fleetint/unenroll.go +++ b/cmd/fleetint/unenroll.go @@ -19,12 +19,14 @@ import ( "context" "database/sql" "fmt" + "strings" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" "github.com/urfave/cli" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" ) @@ -66,7 +68,8 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { // List of metadata keys to delete keysToDelete := []string{ pkgmetadata.MetadataKeyToken, - "sak_token", + agentstate.MetadataKeySAKToken, + agentstate.MetadataKeyBackendBaseURL, "enroll_endpoint", "metrics_endpoint", "logs_endpoint", @@ -74,7 +77,8 @@ func removeEnrollmentMetadata(ctx context.Context, dbRW *sql.DB) error { } // Build batch delete query - query := "DELETE FROM gpud_metadata WHERE key IN (?, ?, ?, ?, ?, ?)" + placeholders := strings.TrimSuffix(strings.Repeat("?, ", len(keysToDelete)), ", ") + query := fmt.Sprintf("DELETE FROM gpud_metadata WHERE key IN (%s)", placeholders) // Convert string slice to []interface{} for ExecContext args := make([]interface{}, len(keysToDelete)) diff --git a/cmd/fleetint/unenroll_test.go b/cmd/fleetint/unenroll_test.go new file mode 100644 index 00000000..b52c9e5e --- /dev/null +++ b/cmd/fleetint/unenroll_test.go @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "context" + "path/filepath" + "testing" + + pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + "github.com/stretchr/testify/require" +) + +func TestRemoveEnrollmentMetadata(t *testing.T) { + t.Parallel() + + db, err := sqlite.Open(filepath.Join(t.TempDir(), "agent.state")) + require.NoError(t, err) + defer db.Close() + + ctx := context.Background() + require.NoError(t, pkgmetadata.CreateTableMetadata(ctx, db)) + + for key, value := range map[string]string{ + pkgmetadata.MetadataKeyToken: "jwt-token", + "sak_token": "sak-token", + "backend_base_url": "https://backend.example.com", + "enroll_endpoint": "https://backend.example.com/api/v1/enroll", + "metrics_endpoint": "https://backend.example.com/api/v1/health/metrics", + "logs_endpoint": "https://backend.example.com/api/v1/health/logs", + "nonce_endpoint": "https://backend.example.com/api/v1/attest/nonce", + "keep_me": "still-here", + } { + require.NoError(t, pkgmetadata.SetMetadata(ctx, db, key, value)) + } + + require.NoError(t, removeEnrollmentMetadata(ctx, db)) + + for _, key := range []string{ + pkgmetadata.MetadataKeyToken, + "sak_token", + "backend_base_url", + "enroll_endpoint", + "metrics_endpoint", + "logs_endpoint", + "nonce_endpoint", + } { + value, err := pkgmetadata.ReadMetadata(ctx, db, key) + require.NoError(t, err) + require.Empty(t, value) + } + + value, err := pkgmetadata.ReadMetadata(ctx, db, "keep_me") + require.NoError(t, err) + require.Equal(t, "still-here", value) +} diff 
--git a/deployments/helm/fleet-intelligence-agent/README.md b/deployments/helm/fleet-intelligence-agent/README.md index ddec247a..c6d75b28 100644 --- a/deployments/helm/fleet-intelligence-agent/README.md +++ b/deployments/helm/fleet-intelligence-agent/README.md @@ -33,7 +33,10 @@ Common values (defaults from `values.yaml`): | `env.FLEETINT_EVENTS_LOOKBACK` | `"1m"` | Lookback window for events export. | | `env.FLEETINT_CHECK_INTERVAL` | `"1m"` | Health check frequency (1s to 24h). | | `env.FLEETINT_RETRY_MAX_ATTEMPTS` | `"3"` | Max retry attempts for failed exports. | -| `env.FLEETINT_ATTESTATION_JITTER_ENABLED` | `"true"` | Enable/disable attestation jitter. | +| `env.FLEETINT_INVENTORY_ENABLED` | `"true"` | Enable or disable the inventory loop. | +| `env.FLEETINT_INVENTORY_INTERVAL` | `"1h"` | Inventory loop interval override. | +| `env.FLEETINT_ATTESTATION_ENABLED` | `"true"` | Enable or disable the attestation loop. | +| `env.FLEETINT_ATTESTATION_INITIAL_INTERVAL` | `"5m"` | Initial attestation bootstrap interval before the first successful attestation. | | `env.FLEETINT_ATTESTATION_INTERVAL` | `"24h"` | Attestation interval override. | | `env.HTTP_PROXY` | `""` | Optional HTTP proxy for outbound requests. | | `env.HTTPS_PROXY` | `""` | Optional HTTPS proxy for outbound requests. 
| diff --git a/deployments/helm/fleet-intelligence-agent/values.yaml b/deployments/helm/fleet-intelligence-agent/values.yaml index 0e1b75ee..f7d2150b 100644 --- a/deployments/helm/fleet-intelligence-agent/values.yaml +++ b/deployments/helm/fleet-intelligence-agent/values.yaml @@ -40,7 +40,10 @@ env: FLEETINT_EVENTS_LOOKBACK: "1m" FLEETINT_CHECK_INTERVAL: "1m" FLEETINT_RETRY_MAX_ATTEMPTS: "3" - FLEETINT_ATTESTATION_JITTER_ENABLED: "true" + FLEETINT_INVENTORY_ENABLED: "true" + FLEETINT_INVENTORY_INTERVAL: "1h" + FLEETINT_ATTESTATION_ENABLED: "true" + FLEETINT_ATTESTATION_INITIAL_INTERVAL: "5m" FLEETINT_ATTESTATION_INTERVAL: "24h" HTTP_PROXY: "" HTTPS_PROXY: "" diff --git a/deployments/packages/systemd/fleetint.env b/deployments/packages/systemd/fleetint.env index e589d782..b06c47d7 100644 --- a/deployments/packages/systemd/fleetint.env +++ b/deployments/packages/systemd/fleetint.env @@ -11,7 +11,10 @@ FLEETINT_METRICS_LOOKBACK="1m" FLEETINT_EVENTS_LOOKBACK="1m" FLEETINT_CHECK_INTERVAL="1m" FLEETINT_RETRY_MAX_ATTEMPTS="3" -FLEETINT_ATTESTATION_JITTER_ENABLED="true" +FLEETINT_INVENTORY_ENABLED="true" +FLEETINT_INVENTORY_INTERVAL="1h" +FLEETINT_ATTESTATION_ENABLED="true" +FLEETINT_ATTESTATION_INITIAL_INTERVAL="5m" FLEETINT_ATTESTATION_INTERVAL="24h" HTTP_PROXY="" HTTPS_PROXY="" diff --git a/docs/configuration.md b/docs/configuration.md index 9ec52f38..b3324345 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -50,7 +50,10 @@ These environment variables are read by `fleetint run` at startup. | `FLEETINT_EVENTS_LOOKBACK` | Lookback window for events included in each export. | `1m` | `/etc/default/fleetint` | `env.FLEETINT_EVENTS_LOOKBACK` | | `FLEETINT_CHECK_INTERVAL` | Health check interval for monitored components. Valid range: `1s` to `24h`. | `1m` | `/etc/default/fleetint` | `env.FLEETINT_CHECK_INTERVAL` | | `FLEETINT_RETRY_MAX_ATTEMPTS` | Maximum retry attempts for failed exports. Minimum: `0`. 
| `3` | `/etc/default/fleetint` | `env.FLEETINT_RETRY_MAX_ATTEMPTS` | -| `FLEETINT_ATTESTATION_JITTER_ENABLED` | Enable random startup jitter for attestation scheduling. | `true` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_JITTER_ENABLED` | +| `FLEETINT_INVENTORY_ENABLED` | Enable or disable the inventory loop. | `true` | `/etc/default/fleetint` | `env.FLEETINT_INVENTORY_ENABLED` | +| `FLEETINT_INVENTORY_INTERVAL` | Inventory loop interval override. Minimum: `1m`. | `1h` | `/etc/default/fleetint` | `env.FLEETINT_INVENTORY_INTERVAL` | +| `FLEETINT_ATTESTATION_ENABLED` | Enable or disable the attestation loop. | `true` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_ENABLED` | +| `FLEETINT_ATTESTATION_INITIAL_INTERVAL` | Initial attestation bootstrap interval before the first successful attestation. Minimum: `1m`. | `5m` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_INITIAL_INTERVAL` | -| `FLEETINT_ATTESTATION_INTERVAL` | Attestation interval override. | `24h` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_INTERVAL` | +| `FLEETINT_ATTESTATION_INTERVAL` | Attestation interval override. Minimum: `1m`. | `24h` | `/etc/default/fleetint` | `env.FLEETINT_ATTESTATION_INTERVAL` | | `HTTP_PROXY` | Proxy URL for outbound HTTP requests. | empty | `/etc/default/fleetint` | `env.HTTP_PROXY` | | `HTTPS_PROXY` | Proxy URL for outbound HTTPS requests. | empty | `/etc/default/fleetint` | `env.HTTPS_PROXY` | @@ -58,7 +61,7 @@ These environment variables are read by `fleetint run` at startup. Notes: - Duration-valued environment variables use Go duration syntax such as `30s`, `1m`, `10m`, or `24h`. -- These environment variables modify the health exporter configuration used by `fleetint run`. +- These environment variables modify the telemetry exporter configuration and runtime loop intervals used by `fleetint run`. - `DCGM_URL` and `DCGM_URL_IS_UNIX_SOCKET` configure connectivity to DCGM HostEngine for DCGM-backed components. 
### Bare Metal Example @@ -95,6 +98,7 @@ env: FLEETINT_COLLECT_INTERVAL: "2m" FLEETINT_INCLUDE_EVENTS: "false" FLEETINT_CHECK_INTERVAL: "30s" + FLEETINT_INVENTORY_INTERVAL: "2m" HTTPS_PROXY: "http://proxy.example.com:3128" ```
diff --git a/internal/agentstate/sqlite.go b/internal/agentstate/sqlite.go
new file mode 100644
index 00000000..285b0644
--- /dev/null
+++ b/internal/agentstate/sqlite.go
@@ -0,0 +1,220 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package agentstate
+
+import (
+	"context"
+	"database/sql"
+	"errors"
+	"fmt"
+	"strings"
+
+	pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata"
+	"github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite"
+	sqlite3 "github.com/mattn/go-sqlite3"
+
+	"github.com/NVIDIA/fleet-intelligence-agent/internal/config"
+	"github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint"
+)
+
+// sqliteState stores agent state in the metadata table of the sqlite state
+// file. Every accessor opens its own connection and closes it before returning.
+type sqliteState struct {
+	// stateFileFn resolves the state file path; tests replace it to point at a temp file.
+	stateFileFn func() (string, error)
+}
+
+// NewSQLite returns a State backed by the agent sqlite metadata database.
+func NewSQLite() State {
+	return &sqliteState{stateFileFn: config.DefaultStateFile}
+}
+
+// GetBackendBaseURL returns the stored backend base URL and whether one was found.
+//
+// When the dedicated key is empty or absent, the legacy enroll/metrics/logs/nonce
+// endpoint keys are checked in that order and the base URL is derived from the
+// first non-empty one. A missing row or missing metadata table counts as absent.
+func (s *sqliteState) GetBackendBaseURL(ctx context.Context) (string, bool, error) {
+	db, err := s.openReadOnly()
+	if err != nil {
+		return "", false, err
+	}
+	defer db.Close()
+
+	value, err := pkgmetadata.ReadMetadata(ctx, db, MetadataKeyBackendBaseURL)
+	switch {
+	case err == nil && value != "":
+		return value, true, nil
+	case err == nil || isMetadataAbsentErr(err):
+		// fall through to legacy endpoint keys
+	default:
+		return "", false, fmt.Errorf("read metadata %q: %w", MetadataKeyBackendBaseURL, err)
+	}
+
+	for _, key := range []string{"enroll_endpoint", "metrics_endpoint", "logs_endpoint", "nonce_endpoint"} {
+		value, err := pkgmetadata.ReadMetadata(ctx, db, key)
+		switch {
+		case err == nil && value == "":
+			continue
+		case err == nil:
+			baseURL, err := endpoint.DeriveBackendBaseURL(value)
+			if err != nil {
+				return "", false, fmt.Errorf("derive backend base URL from metadata %q: %w", key, err)
+			}
+			return baseURL, true, nil
+		case isMetadataAbsentErr(err):
+			continue
+		default:
+			return "", false, fmt.Errorf("read metadata %q: %w", key, err)
+		}
+	}
+
+	return "", false, nil
+}
+
+// SetBackendBaseURL stores value as the backend base URL after it passes
+// endpoint.ValidateBackendEndpoint; an invalid value is rejected without writing.
+func (s *sqliteState) SetBackendBaseURL(ctx context.Context, value string) error {
+	if _, err := endpoint.ValidateBackendEndpoint(value); err != nil {
+		return fmt.Errorf("validate backend base URL: %w", err)
+	}
+	return s.setMetadata(ctx, MetadataKeyBackendBaseURL, value)
+}
+
+// GetJWT returns the token stored under pkgmetadata.MetadataKeyToken.
+func (s *sqliteState) GetJWT(ctx context.Context) (string, bool, error) {
+	return s.getMetadata(ctx, pkgmetadata.MetadataKeyToken)
+}
+
+// SetJWT stores value under pkgmetadata.MetadataKeyToken.
+func (s *sqliteState) SetJWT(ctx context.Context, value string) error {
+	return s.setMetadata(ctx, pkgmetadata.MetadataKeyToken, value)
+}
+
+// GetSAK returns the token stored under MetadataKeySAKToken.
+func (s *sqliteState) GetSAK(ctx context.Context) (string, bool, error) {
+	return s.getMetadata(ctx, MetadataKeySAKToken)
+}
+
+// SetSAK stores value under MetadataKeySAKToken.
+func (s *sqliteState) SetSAK(ctx context.Context, value string) error {
+	return s.setMetadata(ctx, MetadataKeySAKToken, value)
+}
+
+// GetNodeUUID returns the value stored under pkgmetadata.MetadataKeyMachineID.
+func (s *sqliteState) GetNodeUUID(ctx context.Context) (string, bool, error) {
+	return s.getMetadata(ctx, pkgmetadata.MetadataKeyMachineID)
+}
+
+// SetNodeUUID stores value under pkgmetadata.MetadataKeyMachineID.
+func (s *sqliteState) SetNodeUUID(ctx context.Context, value string) error {
+	return s.setMetadata(ctx, pkgmetadata.MetadataKeyMachineID, value)
+}
+
+// getMetadata reads key over a read-only connection. A missing row, a missing
+// metadata table, or an empty value is reported as ("", false, nil).
+func (s *sqliteState) getMetadata(ctx context.Context, key string) (string, bool, error) {
+	db, err := s.openReadOnly()
+	if err != nil {
+		return "", false, err
+	}
+	defer db.Close()
+
+	value, err := pkgmetadata.ReadMetadata(ctx, db, key)
+	if err != nil {
+		if isMetadataAbsentErr(err) {
+			return "", false, nil
+		}
+		return "", false, fmt.Errorf("read metadata %q: %w", key, err)
+	}
+	if value == "" {
+		return "", false, nil
+	}
+	return value, true, nil
+}
+
+// setMetadata writes key=value to the metadata table, creating the table if
+// needed, and tightens the state file permissions around the write.
+//
+// The state file path is resolved first so a lookup failure is returned instead
+// of silently skipping the permission hardening. Permissions are applied before
+// the write and again after it. NOTE(review): the second pass assumes
+// config.SecureStateFilePermissions also covers "-wal"/"-shm" sidecar files the
+// write may create; confirm against its implementation.
+func (s *sqliteState) setMetadata(ctx context.Context, key, value string) error {
+	stateFile, err := s.stateFileFn()
+	if err != nil {
+		return fmt.Errorf("get state file path: %w", err)
+	}
+
+	db, err := s.openReadWrite()
+	if err != nil {
+		return err
+	}
+	defer db.Close()
+
+	if err := config.SecureStateFilePermissions(stateFile); err != nil {
+		return fmt.Errorf("secure state file permissions: %w", err)
+	}
+	if err := pkgmetadata.CreateTableMetadata(ctx, db); err != nil {
+		return fmt.Errorf("create metadata table: %w", err)
+	}
+	if err := pkgmetadata.SetMetadata(ctx, db, key, value); err != nil {
+		return fmt.Errorf("set metadata %q: %w", key, err)
+	}
+	// Re-apply after the write so files created by it are not left with default permissions.
+	if err := config.SecureStateFilePermissions(stateFile); err != nil {
+		return fmt.Errorf("secure state file permissions: %w", err)
+	}
+	return nil
+}
+
+// openReadOnly opens the state database in read-only mode; the caller must close it.
+func (s *sqliteState) openReadOnly() (*sql.DB, error) {
+	stateFile, err := s.stateFileFn()
+	if err != nil {
+		return nil, fmt.Errorf("get state file path: %w", err)
+	}
+	db, err := sqlite.Open(stateFile, sqlite.WithReadOnly(true))
+	if err != nil {
+		return nil, fmt.Errorf("open state database read-only: %w", err)
+	}
+	return db, nil
+}
+
+// isMetadataAbsentErr reports whether err means the value is simply not stored:
+// sql.ErrNoRows, or a generic sqlite3 error whose message contains "no such table".
+func isMetadataAbsentErr(err error) bool {
+	if errors.Is(err, sql.ErrNoRows) {
+		return true
+	}
+
+	var sqliteErr sqlite3.Error
+	return errors.As(err, &sqliteErr) && sqliteErr.Code == sqlite3.ErrError && strings.Contains(strings.ToLower(sqliteErr.Error()), "no such table")
+}
+
+// openReadWrite opens the state database for reading and writing; the caller must close it.
+func (s *sqliteState) openReadWrite() (*sql.DB, error) {
+	stateFile, err := s.stateFileFn()
+	if err != nil {
+		return nil, fmt.Errorf("get state file path: %w", err)
+	}
+	db, err := sqlite.Open(stateFile)
+	if err != nil {
+		return nil, fmt.Errorf("open state database: %w", err)
+	}
+	return db, nil
+}
diff --git a/internal/agentstate/sqlite_test.go b/internal/agentstate/sqlite_test.go new file mode 100644 index 00000000..175b7ad6 --- /dev/null +++ b/internal/agentstate/sqlite_test.go @@ -0,0 +1,182 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package agentstate + +import ( + "context" + "database/sql" + "errors" + "path/filepath" + "testing" + + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + "github.com/stretchr/testify/require" +) + +func newTestSQLiteState(t *testing.T) *sqliteState { + t.Helper() + stateFile := filepath.Join(t.TempDir(), "agent.state") + return &sqliteState{ + stateFileFn: func() (string, error) { + return stateFile, nil + }, + } +} + +func TestSQLiteStateRoundTrip(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.SetBackendBaseURL(ctx, "https://backend.example.com") + require.NoError(t, err) + err = state.SetJWT(ctx, "jwt-token") + require.NoError(t, err) + err = state.SetSAK(ctx, "sak-token") + require.NoError(t, err) + err = state.SetNodeUUID(ctx, "node-1") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "https://backend.example.com", value) + + value, ok, err = state.GetJWT(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "jwt-token", value) + + value, ok, err = state.GetSAK(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "sak-token", value) + + value, ok, err = state.GetNodeUUID(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "node-1", value) +} + +func TestSQLiteStateMissingValue(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.SetJWT(ctx, "jwt-token") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, value) +} + +func TestSQLiteStateMissingMetadataTableIsTreatedAsAbsent(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + stateFile, err := state.stateFileFn() + require.NoError(t, err) + + db, err := sqlite.Open(stateFile) + 
require.NoError(t, err) + _, err = db.Exec("PRAGMA user_version = 1") + require.NoError(t, err) + require.NoError(t, db.Close()) + + for _, get := range []func(context.Context) (string, bool, error){ + state.GetJWT, + state.GetSAK, + state.GetNodeUUID, + } { + value, ok, err := get(ctx) + require.NoError(t, err) + require.False(t, ok) + require.Empty(t, value) + } +} + +func TestSQLiteStateSetBackendBaseURLValidatesInput(t *testing.T) { + t.Parallel() + + state := newTestSQLiteState(t) + + err := state.SetBackendBaseURL(context.Background(), "http://example.com") + require.Error(t, err) + + err = state.SetBackendBaseURL(context.Background(), "not-a-url") + require.Error(t, err) +} + +func TestSQLiteStateGetBackendBaseURLFallsBackToLegacyEndpoints(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + err := state.setMetadata(ctx, "metrics_endpoint", "https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + + value, ok, err := state.GetBackendBaseURL(ctx) + require.NoError(t, err) + require.True(t, ok) + require.Equal(t, "https://backend.example.com", value) +} + +func TestSQLiteStateStateFileErrors(t *testing.T) { + t.Parallel() + + boom := errors.New("boom") + state := &sqliteState{ + stateFileFn: func() (string, error) { + return "", boom + }, + } + + _, _, err := state.GetJWT(context.Background()) + require.ErrorIs(t, err, boom) + + err = state.SetJWT(context.Background(), "jwt-token") + require.ErrorIs(t, err, boom) +} + +func TestNewSQLite(t *testing.T) { + t.Parallel() + require.NotNil(t, NewSQLite()) +} + +func TestSQLiteStateGetBackendBaseURLPropagatesReadErrors(t *testing.T) { + t.Parallel() + + ctx := context.Background() + state := newTestSQLiteState(t) + + stateFile, err := state.stateFileFn() + require.NoError(t, err) + + db, err := sqlite.Open(stateFile) + require.NoError(t, err) + require.NoError(t, db.Close()) + + _, _, err = state.GetBackendBaseURL(ctx) + require.Error(t, err) + 
require.NotErrorIs(t, err, sql.ErrNoRows) +} diff --git a/internal/agentstate/state.go b/internal/agentstate/state.go new file mode 100644 index 00000000..2429c76e --- /dev/null +++ b/internal/agentstate/state.go @@ -0,0 +1,39 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package agentstate centralizes access to local persisted agent state. +package agentstate + +import "context" + +const ( + MetadataKeyBackendBaseURL = "backend_base_url" + MetadataKeySAKToken = "sak_token" +) + +// State provides local persisted metadata/state access for backend workflows. 
+type State interface { + GetBackendBaseURL(ctx context.Context) (value string, ok bool, err error) + SetBackendBaseURL(ctx context.Context, value string) error + + GetJWT(ctx context.Context) (value string, ok bool, err error) + SetJWT(ctx context.Context, value string) error + + GetSAK(ctx context.Context) (value string, ok bool, err error) + SetSAK(ctx context.Context, value string) error + + GetNodeUUID(ctx context.Context) (value string, ok bool, err error) + SetNodeUUID(ctx context.Context, value string) error +} diff --git a/internal/attestation/attestation.go b/internal/attestation/attestation.go deleted file mode 100644 index e68984d0..00000000 --- a/internal/attestation/attestation.go +++ /dev/null @@ -1,602 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -// Package attestation provides functionality for GPU attestation -package attestation - -import ( - "bytes" - "context" - "crypto/rand" - "encoding/json" - "fmt" - "math/big" - "net/http" - "os/exec" - "sync" - "time" - - pkgfile "github.com/NVIDIA/fleet-intelligence-sdk/pkg/file" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" - "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" -) - -var defaultStateFileFn = config.DefaultStateFile - -// EvidenceItem represents a single evidence item from the attestation SDK -type EvidenceItem struct { - Arch string `json:"arch"` - Certificate string `json:"certificate"` - DriverVersion string `json:"driver_version"` - Evidence string `json:"evidence"` - Nonce string `json:"nonce"` - VBIOSVersion string `json:"vbios_version"` - Version string `json:"version"` -} - -// AttestationSDKResponse represents the complete response from the attestation SDK -type AttestationSDKResponse struct { - Evidences []EvidenceItem `json:"evidences"` - ResultCode int `json:"result_code"` - ResultMessage string `json:"result_message"` -} - -// AttestationData represents the complete attestation information including SDK response and timestamp -type AttestationData struct { - SDKResponse AttestationSDKResponse `json:"sdk_response"` - NonceRefreshTimestamp time.Time `json:"nonce_refresh_timestamp"` - Success bool `json:"success"` - ErrorMessage string `json:"error_message,omitempty"` -} - -// Manager manages the attestation process with configurable intervals -type Manager struct { - ctx context.Context - cancel context.CancelFunc - wg sync.WaitGroup - cache *cache - nvmlInstance nvidianvml.Instance - config *config.AttestationConfig -} - -// cache holds the latest 
attestation results with thread-safe access -type cache struct { - mu sync.RWMutex - attestationData *AttestationData - lastUpdated time.Time -} - -// GetAttestationData returns the current attestation data thread-safely -func (c *cache) GetAttestationData() *AttestationData { - c.mu.RLock() - defer c.mu.RUnlock() - return c.attestationData -} - -// updateAttestationData updates the attestation cache thread-safely -func (c *cache) updateAttestationData(attestationData *AttestationData) { - c.mu.Lock() - defer c.mu.Unlock() - c.attestationData = attestationData - c.lastUpdated = time.Now().UTC() -} - -// NewManager creates a new attestation manager -func NewManager(ctx context.Context, nvmlInstance nvidianvml.Instance, config *config.AttestationConfig) *Manager { - cctx, cancel := context.WithCancel(ctx) - - // Use 24 hours as default if not specified or invalid - if config.Interval.Duration <= 0 { - config.Interval.Duration = 24 * time.Hour - } - - return &Manager{ - ctx: cctx, - cancel: cancel, - cache: &cache{}, - nvmlInstance: nvmlInstance, - config: config, - } -} - -// GetAttestationData returns the current attestation data directly -func (m *Manager) GetAttestationData() *AttestationData { - return m.cache.GetAttestationData() -} - -// IsAttestationDataUpdated checks if attestation data has been updated since the given time -func (m *Manager) IsAttestationDataUpdated(since time.Time) bool { - return m.cache.isUpdatedSince(since) -} - -// isUpdatedSince checks if the cache has been updated since the given time -func (c *cache) isUpdatedSince(since time.Time) bool { - c.mu.RLock() - defer c.mu.RUnlock() - return c.lastUpdated.After(since) -} - -// retryInterval is the shorter interval used when agent is not enrolled yet -const retryInterval = 5 * time.Minute - -// Start begins the attestation loop with jitter to prevent thundering herd -// Uses dynamic intervals: shorter retry interval when not enrolled, normal interval otherwise -func (m *Manager) Start() { - 
log.Logger.Infow("Starting attestation manager with thundering herd prevention") - - m.wg.Add(1) - go func() { - defer m.wg.Done() - // Add initial jitter to spread out startup requests (0-30 minutes) if enabled - var initialJitter time.Duration - if m.config.JitterEnabled { - initialJitter = m.calculateJitter(30 * time.Minute) - log.Logger.Infow("Adding initial startup jitter to prevent thundering herd", - "jitter_duration", initialJitter) - } else { - log.Logger.Infow("Startup jitter disabled for testing") - } - - // Wait for initial jitter before first attestation - select { - case <-m.ctx.Done(): - log.Logger.Infow("Context done during initial jitter") - return - case <-time.After(initialJitter): - // Continue to first attestation - } - - // Run first attestation with additional jitter - shouldRetrySoon := m.runAttestationWithJitter() - - // Create ticker with configurable interval (default 24 hours) - // Will be reset after each attestation based on result - ticker := time.NewTicker(m.getNextInterval(shouldRetrySoon)) - defer ticker.Stop() - - log.Logger.Infow("Attestation ticker started", "interval", m.getNextInterval(shouldRetrySoon)) - - for { - select { - case <-m.ctx.Done(): - log.Logger.Infow("Context done, stopping attestation manager") - return - case <-ticker.C: - shouldRetrySoon = m.runAttestationWithJitter() - nextInterval := m.getNextInterval(shouldRetrySoon) - ticker.Reset(nextInterval) - } - } - }() -} - -// getNextInterval returns the appropriate interval based on whether we should retry soon -func (m *Manager) getNextInterval(shouldRetrySoon bool) time.Duration { - if shouldRetrySoon { - // Use the shorter of retryInterval and configured interval - // to avoid slowing down attestation in fast-retry environments (e.g., testing) - interval := retryInterval - if m.config.Interval.Duration < retryInterval { - interval = m.config.Interval.Duration - } - log.Logger.Infow("Agent not enrolled, using retry interval", - "retry_interval", interval) - 
return interval - } - log.Logger.Infow("Using normal attestation interval", - "interval", m.config.Interval.Duration) - return m.config.Interval.Duration -} - -// Stop gracefully shuts down the attestation manager and waits for the -// background goroutine to exit. This ensures that any in-progress call -// to defaultStateFileFn (or any other shared state) finishes before Stop -// returns, which prevents data races in tests and orderly cleanup in production. -func (m *Manager) Stop() { - log.Logger.Infow("Stopping attestation manager") - m.cancel() - m.wg.Wait() -} - -// runAttestation performs the attestation process and updates the cache -// Returns true if attestation should be retried soon (e.g., agent not enrolled yet) -func (m *Manager) runAttestation() bool { - log.Logger.Infow("Starting attestation process") - - // Always update cache with result (success or failure) so server knows status - attestationData := &AttestationData{} - - // Step 1: Get machine ID - log.Logger.Debugw("Getting machine ID in Attestation") - machineId, err := m.getMachineId() - if err != nil { - log.Logger.Errorw("Failed to get machine ID in Attestation", "error", err) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - log.Logger.Debugw("Machine ID retrieved in Attestation", - "machine_id", machineId) - - // Step 2: Load JWT token from metadata database - jwtToken := m.getJWTTokenFromMetadata(m.ctx) - if jwtToken == "" { - if endpoint := m.getEndpointFromMetadata(m.ctx); endpoint != "" { - log.Logger.Errorw("JWT token not found in metadata", "endpoint", endpoint) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = "JWT token not found in agent metadata while agent is enrolled" - m.cache.updateAttestationData(attestationData) - return false - } else { - 
log.Logger.Infow("No backend endpoint found in metadata, agent not enrolled, will retry soon") - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = "No backend endpoint found in metadata, agent is not enrolled" - m.cache.updateAttestationData(attestationData) - return true // Retry soon - agent may enroll shortly - } - } - - // Step 3: Get nonce by calling the envoy endpoint - nonce, nonceRefreshTimestamp, err := m.getNonce(jwtToken, machineId) - if err != nil { - // if agent is not enrolled, it will return in step 2. so here we can directly return the nonce error - log.Logger.Errorw("Failed to get nonce", "error", err) - // SDKResponse and NonceRefreshTimestamp are nil - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - // Update nonce refresh timestamp with actual server response - attestationData.NonceRefreshTimestamp = nonceRefreshTimestamp - - // Step 4: Get evidences from attestation SDK - log.Logger.Debugw("Getting evidences with nonce") - sdkResponse, err := m.getEvidences(nonce) - if err != nil { - log.Logger.Errorw("Failed to get evidences from attestation SDK", "error", err) - // SDKResponse - attestationData.Success = false - attestationData.ErrorMessage = err.Error() - m.cache.updateAttestationData(attestationData) - return false - } - - // Success case: populate all data - attestationData.SDKResponse = *sdkResponse - attestationData.Success = true - attestationData.ErrorMessage = "" - log.Logger.Debugw("Attestation data", "attestation_data", attestationData) - - // Update the attestation cache - m.cache.updateAttestationData(attestationData) - return false -} - -func (m *Manager) getNonce(jwtToken string, machineId string) (string, time.Time, error) { - endpointURL, err := m.getValidatedNonceEndpoint(m.ctx) - if err != nil { - return "", time.Time{}, err - } - - // Request payload (only 
include machine ID, JWT token goes in header) - requestBody, err := json.Marshal(map[string]string{ - "uuid": machineId, - }) - if err != nil { - log.Logger.Debugw("failed to marshal request body in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - // Create HTTP request tied to the manager context so that Stop() cancellation - // unblocks the request and prevents wg.Wait() from hanging indefinitely. - req, err := http.NewRequestWithContext(m.ctx, "POST", endpointURL, bytes.NewBuffer(requestBody)) - if err != nil { - log.Logger.Debugw("failed to create HTTP request in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwtToken)) - - // Make the HTTP request. Disable redirects so a compromised backend - // cannot bounce us to an internal service (SSRF). - client := &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, - } - resp, err := client.Do(req) - if err != nil { - log.Logger.Debugw("failed to make POST request in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - defer resp.Body.Close() - - // Parsing the response - var response struct { - Nonce string `json:"nonce"` - NonceRefreshTimestamp time.Time `json:"nonce_refresh_timestamp"` - Error string `json:"error"` - } - - log.Logger.Debugw("HTTP Response received:", - "status_code", resp.StatusCode, - "status", resp.Status, - "content_type", resp.Header.Get("Content-Type")) - - if resp.StatusCode != http.StatusOK { - return "", time.Time{}, fmt.Errorf("nonce endpoint returned HTTP %d", resp.StatusCode) - } - - if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { - log.Logger.Debugw("failed to decode response in nonce endpoint request", "error", err) - return "", time.Time{}, err - } - - if response.Error != 
"" { - log.Logger.Debugw("error from server in nonce endpoint request", "error", response.Error) - return "", time.Time{}, fmt.Errorf("nonce endpoint returned error: %s", response.Error) - } - - log.Logger.Debugw("Nonce received from server", "nonce_refresh_timestamp", response.NonceRefreshTimestamp) - - return response.Nonce, response.NonceRefreshTimestamp, nil -} - -func (m *Manager) getValidatedNonceEndpoint(ctx context.Context) (string, error) { - nonceEndpoint := m.getNonceEndpointFromMetadata(ctx) - if nonceEndpoint == "" { - return "", fmt.Errorf("nonce endpoint not found in metadata") - } - - validated, err := endpoint.ValidateBackendEndpoint(nonceEndpoint) - if err != nil { - return "", fmt.Errorf("invalid nonce endpoint: %w", err) - } - - return validated.String(), nil -} - -// getJWTTokenFromMetadata retrieves the JWT token from the metadata database -func (m *Manager) getJWTTokenFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - // Load JWT token from metadata - if token, err := pkgmetadata.ReadMetadata(ctx, dbRO, pkgmetadata.MetadataKeyToken); err == nil && token != "" { - return token - } - - log.Logger.Debugw("JWT token not found in metadata") - return "" -} - -func (m *Manager) getEndpointFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - // Load endpoint from metadata - if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, pkgmetadata.MetadataKeyEndpoint); 
err == nil && endpoint != "" { - return endpoint - } - - log.Logger.Debugw("backend endpoint not found in metadata") - return "" -} - -// getNonceEndpointFromMetadata retrieves the nonce endpoint from the metadata database -func (m *Manager) getNonceEndpointFromMetadata(ctx context.Context) string { - stateFile, err := defaultStateFileFn() - if err != nil { - log.Logger.Debugw("failed to get state file path", "error", err) - return "" - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - log.Logger.Debugw("failed to open state database", "error", err) - return "" - } - defer dbRO.Close() - - // Load nonce endpoint from metadata - if endpoint, err := pkgmetadata.ReadMetadata(ctx, dbRO, "nonce_endpoint"); err == nil && endpoint != "" { - return endpoint - } - - log.Logger.Debugw("Nonce endpoint not found in metadata") - return "" -} - -func (m *Manager) getMachineId() (string, error) { - stateFile, err := defaultStateFileFn() - if err != nil { - return "", fmt.Errorf("failed to get state file path: %w", err) - } - - dbRO, err := sqlite.Open(stateFile) - if err != nil { - return "", fmt.Errorf("failed to open state database: %w", err) - } - defer dbRO.Close() - - machineID, err := pkgmetadata.ReadMachineID(m.ctx, dbRO) - if err != nil { - return "", fmt.Errorf("failed to read machine ID from metadata: %w", err) - } - - if machineID == "" { - return "", fmt.Errorf("machine ID not found in metadata") - } - - return machineID, nil -} - -// validateNonce verifies that a nonce returned by the backend is safe to forward -// as a command-line argument to nvattest. It enforces an allowlist of characters -// (hex, base64url, and common padding/separator symbols) and a maximum length so -// that a compromised backend cannot craft an argument that exploits nvattest's own -// argument parser. 
-func validateNonce(nonce string) error { - if nonce == "" { - return fmt.Errorf("nonce is empty") - } - const maxLen = 512 - if len(nonce) > maxLen { - return fmt.Errorf("nonce length %d exceeds maximum of %d characters", len(nonce), maxLen) - } - for i, c := range nonce { - switch { - case c >= '0' && c <= '9', - c >= 'a' && c <= 'z', - c >= 'A' && c <= 'Z', - c == '-', c == '_', c == '=', c == '+', c == '/': - // allowed - default: - return fmt.Errorf("nonce contains invalid character %q at position %d", c, i) - } - } - return nil -} - -func (m *Manager) getEvidences(nonce string) (*AttestationSDKResponse, error) { - if err := validateNonce(nonce); err != nil { - return nil, fmt.Errorf("invalid nonce received from backend: %w", err) - } - - log.Logger.Infow("Calling attestation SDK CLI") - - // Execute nvattest command - // Set timeout to prevent hanging, derived from manager context to respect cancellation - ctx, cancel := context.WithTimeout(m.ctx, 60*time.Second) - defer cancel() - - nvattestPath, err := pkgfile.LocateExecutable("nvattest") - if err != nil { - return nil, fmt.Errorf("failed to locate attestation CLI: %w", err) - } - cmd := exec.CommandContext(ctx, nvattestPath, "collect-evidence", "--gpu-evidence-source=corelib", "--nonce", nonce, "--gpu-architecture", "blackwell", "--format", "json") - - // Capture stdout (JSON) and stderr (logs) separately - var stdout, stderr bytes.Buffer - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - runErr := cmd.Run() - - log.Logger.Debugw("Attestation CLI completed", "exit_error", runErr, "stdout", stdout.String(), "stderr", stderr.String()) - - if runErr != nil { - // If stdout is empty, it means the command failed completely (e.g. 
command not found) - if stdout.Len() == 0 { - return nil, fmt.Errorf("attestation CLI execution failed: %w (stderr: %s)", runErr, stderr.String()) - } - // If stdout has content, we continue to try parsing the JSON response - log.Logger.Warnw("Attestation CLI returned error exit code but has output, attempting to parse", "error", runErr) - } - - // Parse the JSON response from stdout (clean JSON without log messages) - var response AttestationSDKResponse - if parseErr := json.Unmarshal(stdout.Bytes(), &response); parseErr != nil { - log.Logger.Debugw("Failed to parse attestation CLI JSON response", "parse_error", parseErr) - return nil, fmt.Errorf("failed to parse CLI response: %w (stderr: %s, stdout: %s, exec: %v)", parseErr, stderr.String(), stdout.String(), runErr) - } - - log.Logger.Infow("Successfully called attestation SDK", - "result_code", response.ResultCode, - "result_message", response.ResultMessage, - "evidences_count", len(response.Evidences)) - - return &response, nil -} - -// calculateJitter returns a random duration between 0 and maxJitter to prevent thundering herd -func (m *Manager) calculateJitter(maxJitter time.Duration) time.Duration { - if maxJitter <= 0 { - return 0 - } - - // Generate cryptographically secure random number - maxMs := int64(maxJitter / time.Millisecond) - if maxMs <= 0 { - return 0 - } - - randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) - if err != nil { - log.Logger.Warnw("Failed to generate secure random jitter, using fallback", "error", err) - // Fallback to simple time-based pseudo-random - return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond - } - - return time.Duration(randomMs.Int64()) * time.Millisecond -} - -// runAttestationWithJitter runs attestation with additional per-request jitter -// Returns true if attestation should be retried soon (e.g., agent not enrolled yet) -func (m *Manager) runAttestationWithJitter() bool { - if !m.config.JitterEnabled { - log.Logger.Infow("Running 
attestation immediately (jitter disabled)") - return m.runAttestation() - } - - // Add significant jitter (0–30 minutes) for each request to spread load across a window, - // reducing thundering herd risk across many agents. - requestJitter := m.calculateJitter(30 * time.Minute) - log.Logger.Infow("Adding request jitter for thundering herd prevention", - "jitter_duration", requestJitter, - "max_jitter", "30 minutes") - - select { - case <-m.ctx.Done(): - return false - case <-time.After(requestJitter): - return m.runAttestation() - } -} diff --git a/internal/attestation/attestation_test.go b/internal/attestation/attestation_test.go deleted file mode 100644 index 3f11e87e..00000000 --- a/internal/attestation/attestation_test.go +++ /dev/null @@ -1,832 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package attestation - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" - "path/filepath" - "strings" - "sync/atomic" - "testing" - "time" - - pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - - "github.com/NVIDIA/fleet-intelligence-agent/internal/config" -) - -func useMissingStateFile(t *testing.T) { - t.Helper() - - orig := defaultStateFileFn - defaultStateFileFn = func() (string, error) { - return filepath.Join(t.TempDir(), "missing", "fleetint.state"), nil - } - t.Cleanup(func() { - defaultStateFileFn = orig - }) -} - -func TestManager_NewManager(t *testing.T) { - ctx := context.Background() - - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing, no jitter - require.NotNil(t, manager) - assert.NotNil(t, manager.ctx) - assert.NotNil(t, manager.cancel) - assert.NotNil(t, manager.cache) - assert.Nil(t, manager.nvmlInstance) // Should be nil as passed -} - -func TestManager_StartStop(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Start should not block (Start() doesn't return error) - manager.Start() - - // Give it a moment to start - time.Sleep(50 * time.Millisecond) - - // Stop should work cleanly - manager.Stop() - - // Verify context is canceled - select { - case <-manager.ctx.Done(): - // Expected - context should be canceled - case <-time.After(100 * time.Millisecond): - t.Error("Expected context to be canceled after Stop()") - } -} - -func TestManager_GetAttestationData(t *testing.T) { - ctx := 
context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Initially should have empty data - attestationData := manager.GetAttestationData() - assert.Nil(t, attestationData) - - // Manually update cache to test getter - testAttestationData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "turing", - Certificate: "test_cert", - DriverVersion: "550.90.07", - Evidence: "test_evidence", - Nonce: "test_nonce", - VBIOSVersion: "90.17.A9.00.0B", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - - manager.cache.updateAttestationData(testAttestationData) - - // Now should return the test data - attestationData = manager.GetAttestationData() - assert.Equal(t, testAttestationData, attestationData) -} - -func TestManager_IsAttestationDataUpdated(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - baseTime := time.Now().UTC() - - // Initially, no updates - assert.False(t, manager.IsAttestationDataUpdated(baseTime)) - - // Update the cache - testData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{{ - Arch: "turing", - Certificate: "test", - DriverVersion: "550.90.07", - Evidence: "test", - Nonce: "test_nonce", - VBIOSVersion: "90.17.A9.00.0B", - Version: "1.0", - }}, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - manager.cache.updateAttestationData(testData) - - // Should now show as updated since baseTime - assert.True(t, manager.IsAttestationDataUpdated(baseTime)) - - // But not updated compared to a future time - futureTime := time.Now().Add(1 * time.Hour) - assert.False(t, 
manager.IsAttestationDataUpdated(futureTime)) -} - -func TestManager_GetMachineId_NoDatabase(t *testing.T) { - useMissingStateFile(t) - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // getMachineId will fail because there's no database with machine ID in test environment - machineId, err := manager.getMachineId() - - // Expected to fail in test environment without proper database setup - assert.Error(t, err) - assert.Empty(t, machineId) -} - -// Define the response struct for testing -type testNonceResponse struct { - Nonce string `json:"nonce"` - NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` - Error string `json:"error,omitempty"` -} - -func useDefaultTransport(t *testing.T, transport http.RoundTripper) { - t.Helper() - - orig := http.DefaultTransport - http.DefaultTransport = transport - t.Cleanup(func() { - http.DefaultTransport = orig - }) -} - -func TestManager_GetNonce_MockHTTP(t *testing.T) { - // Test the nonce parsing logic without actually calling the private method - // This tests the HTTP interaction pattern - - expectedNonce := "test-nonce-12345" - expectedTimestamp := time.Now().UTC() - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request method and content type - assert.Equal(t, "POST", r.Method) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - - // Verify Bearer authorization header - authHeader := r.Header.Get("Authorization") - assert.Equal(t, "Bearer test-jwt-token", authHeader, "Should have Bearer authorization header") - - // Verify request body (should only contain uuid, not token) - var requestBody map[string]string - err := json.NewDecoder(r.Body).Decode(&requestBody) - require.NoError(t, err) - assert.Equal(t, "test-machine-id", requestBody["uuid"]) - assert.NotContains(t, requestBody, "token", 
"Token should not be in request body when using Bearer auth") - - // Send successful response - response := testNonceResponse{ - Nonce: expectedNonce, - NonceRefreshTimestamp: expectedTimestamp, - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("Failed to encode response: %v", err) - } - })) - defer server.Close() - - // Test HTTP request/response parsing manually with Bearer auth - url := server.URL + "/nonce" - requestBody, err := json.Marshal(map[string]string{ - "uuid": "test-machine-id", - }) - require.NoError(t, err) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer test-jwt-token") - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - var response testNonceResponse - err = json.NewDecoder(resp.Body).Decode(&response) - require.NoError(t, err) - - assert.Equal(t, expectedNonce, response.Nonce) - assert.Equal(t, expectedTimestamp.Unix(), response.NonceRefreshTimestamp.Unix()) -} - -func TestManager_GetNonce_ServerError(t *testing.T) { - // Test server error response parsing - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - response := testNonceResponse{ - Error: "Invalid token", - } - w.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(w).Encode(response); err != nil { - t.Errorf("Failed to encode response: %v", err) - } - })) - defer server.Close() - - // Test error response parsing with Bearer auth - url := server.URL + "/nonce" - requestBody, err := json.Marshal(map[string]string{ - "uuid": "test-machine-id", - }) - require.NoError(t, err) - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBody)) - require.NoError(t, err) - req.Header.Set("Content-Type", "application/json") - 
req.Header.Set("Authorization", "Bearer invalid-token") - - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - defer resp.Body.Close() - - var response testNonceResponse - err = json.NewDecoder(resp.Body).Decode(&response) - require.NoError(t, err) - - assert.Equal(t, "Invalid token", response.Error) - assert.Empty(t, response.Nonce) - assert.True(t, response.NonceRefreshTimestamp.IsZero()) -} - -func TestManager_GetNonce_RejectsNonOKStatus(t *testing.T) { - manager := newTestManager(t) - var redirectTargetCalled atomic.Bool - var server *httptest.Server - server = httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path == "/redirected" { - redirectTargetCalled.Store(true) - w.WriteHeader(http.StatusOK) - return - } - http.Redirect(w, r, server.URL+"/redirected", http.StatusFound) - })) - defer server.Close() - - useDefaultTransport(t, server.Client().Transport) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": server.URL, - }) - useTestStateFile(t, stateFile) - - nonce, refresh, err := manager.getNonce("test-jwt-token", "test-machine-id") - - require.Error(t, err) - assert.Contains(t, err.Error(), "nonce endpoint returned HTTP 302") - assert.Empty(t, nonce) - assert.True(t, refresh.IsZero()) - assert.False(t, redirectTargetCalled.Load()) -} - -func TestManager_GetNonce_RejectsServerErrorPayload(t *testing.T) { - manager := newTestManager(t) - server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - require.NoError(t, json.NewEncoder(w).Encode(map[string]string{ - "error": "Invalid token", - })) - })) - defer server.Close() - - useDefaultTransport(t, server.Client().Transport) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": server.URL, - }) - useTestStateFile(t, stateFile) - - nonce, refresh, err := 
manager.getNonce("test-jwt-token", "test-machine-id") - - require.Error(t, err) - assert.Contains(t, err.Error(), "nonce endpoint returned error: Invalid token") - assert.Empty(t, nonce) - assert.True(t, refresh.IsZero()) -} - -func TestManager_GetValidatedNonceEndpoint_UsesStoredNonceEndpoint(t *testing.T) { - manager := newTestManager(t) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": "https://backend.example.com/api/v1/health/nonce", - }) - useTestStateFile(t, stateFile) - - got, err := manager.getValidatedNonceEndpoint(context.Background()) - require.NoError(t, err) - assert.Equal(t, "https://backend.example.com/api/v1/health/nonce", got) -} - -func TestManager_GetValidatedNonceEndpoint_RejectsTamperedStoredNonceEndpoint(t *testing.T) { - manager := newTestManager(t) - stateFile := setupAttestationMetadataDB(t, map[string]string{ - "nonce_endpoint": "http://evil.example.com/api/v1/health/nonce", - }) - useTestStateFile(t, stateFile) - - _, err := manager.getValidatedNonceEndpoint(context.Background()) - require.Error(t, err) - assert.Contains(t, err.Error(), "invalid nonce endpoint") - assert.Contains(t, err.Error(), "https") -} - -func newTestManager(t *testing.T) *Manager { - t.Helper() - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - return NewManager(ctx, nil, cfg) -} - -func setupAttestationMetadataDB(t *testing.T, entries map[string]string) string { - t.Helper() - - stateFile := filepath.Join(t.TempDir(), "fleetint.state") - db, err := sqlite.Open(stateFile) - require.NoError(t, err) - - err = pkgmetadata.CreateTableMetadata(context.Background(), db) - require.NoError(t, err) - - for key, value := range entries { - err = pkgmetadata.SetMetadata(context.Background(), db, key, value) - require.NoError(t, err) - } - - err = db.Close() - require.NoError(t, err) - - return stateFile -} - -func useTestStateFile(t 
*testing.T, stateFile string) { - t.Helper() - - orig := defaultStateFileFn - defaultStateFileFn = func() (string, error) { - return stateFile, nil - } - t.Cleanup(func() { - defaultStateFileFn = orig - }) -} - -func TestManager_GetEvidences(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - testNonce := "931d8dd0add203ac3d8b4fbde75e115278eefcdceac5b87671a748f32364dfcb" - - sdkResponse, err := manager.getEvidences(testNonce) - - // In CI environment, the nvattest binary might not exist or directory might not exist - if err != nil { - // If binary is missing, this is expected in CI - if strings.Contains(err.Error(), "executable file not found") || strings.Contains(err.Error(), "no such file or directory") || strings.Contains(err.Error(), "not found in PATH") || strings.Contains(err.Error(), "The following arguments were not expected") { - t.Log("Attestation CLI binary or directory not found (expected in CI)") - return - } - // If it's another error, fail the test - require.NoError(t, err, "Unexpected error running attestation CLI") - } - - assert.NotNil(t, sdkResponse) - - // In test environment with binary present but no GPU (or mock), CLI may fail - // Check for expected real CLI response structure - if sdkResponse.ResultCode == 0 { - // Success case (when running on real attestation-capable hardware) - assert.Equal(t, "Ok", sdkResponse.ResultMessage) - assert.NotEmpty(t, sdkResponse.Evidences, "Should have evidences on success") - t.Log("Attestation CLI succeeded - running on attestation-capable hardware") - } else { - // Expected failure case (test environment without proper attestation hardware) - // We expect a structured error response from the CLI - t.Logf("Attestation CLI failed as expected in test environment: %s (Code: %d)", - sdkResponse.ResultMessage, sdkResponse.ResultCode) - } -} - -func 
TestCache_ThreadSafety(t *testing.T) { - cache := &cache{} - - // Test concurrent reads and writes - done := make(chan bool, 10) - - // Start multiple goroutines writing to cache - for i := 0; i < 5; i++ { - go func(id int) { - defer func() { done <- true }() - - for j := 0; j < 10; j++ { - testData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: fmt.Sprintf("cert-%d-%d", id, j), - DriverVersion: "575.28", - Evidence: fmt.Sprintf("evidence-%d-%d", id, j), - Nonce: fmt.Sprintf("nonce-%d-%d", id, j), - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - //NonceRefreshTimestamp: time.Now().UTC().Add(time.Duration(id*j) * time.Millisecond), - }, - NonceRefreshTimestamp: time.Now().UTC().Add(time.Duration(id*j) * time.Millisecond), - } - cache.updateAttestationData(testData) - time.Sleep(time.Millisecond) // Small delay to increase chance of concurrent access - } - }(i) - } - - // Start multiple goroutines reading from cache - for i := 0; i < 5; i++ { - go func(id int) { - defer func() { done <- true }() - - for j := 0; j < 10; j++ { - cache.GetAttestationData() - baseTime := time.Now().UTC().Add(-time.Duration(j) * time.Second) - cache.isUpdatedSince(baseTime) - time.Sleep(time.Millisecond) - } - }(i) - } - - // Wait for all goroutines to complete - for i := 0; i < 10; i++ { - select { - case <-done: - // Good, goroutine completed - case <-time.After(15 * time.Second): - t.Fatal("Goroutines did not complete within timeout - possible deadlock") - } - } - - // Verify cache is still functional - attestationData := cache.GetAttestationData() - assert.NotNil(t, attestationData) // Should have some data from the concurrent writes -} - -func TestManager_RunAttestation_WithFallback(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - 
manager := NewManager(ctx, nil, cfg) - - // Since runAttestation is called by Start(), we need to test it differently - // We'll test the fallback behavior by checking that it handles errors gracefully - - // The method should not panic when NVML is not available - assert.NotPanics(t, func() { - manager.runAttestation() - }) - - // After running with fallbacks, cache should contain failure information - attestationData := manager.GetAttestationData() - if attestationData != nil { - // If we have attestation data, it should indicate failure - assert.False(t, attestationData.Success, "Should indicate failure") - assert.NotEmpty(t, attestationData.ErrorMessage, "Should have error message") - t.Log("Attestation failed as expected with error:", attestationData.ErrorMessage) - } else { - t.Log("No attestation data available - this can happen if attestation manager didn't run") - } -} - -func TestManager_IntegrationTest(t *testing.T) { - // This is a more comprehensive test that tests the update tracking functionality - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Test the update tracking over time - baseTime := time.Now().UTC() - - // Initially no updates - assert.False(t, manager.IsAttestationDataUpdated(baseTime)) - - // Manually update cache (simulating successful attestation) - testAttestationData := &AttestationData{ - SDKResponse: AttestationSDKResponse{ - Evidences: []EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: "integration-cert", - DriverVersion: "575.28", - Evidence: "integration-evidence", - Nonce: "integration-nonce", - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now().UTC(), - } - manager.cache.updateAttestationData(testAttestationData) - - // Now should show updates - assert.True(t, 
manager.IsAttestationDataUpdated(baseTime)) - - // Data should be retrievable - attestationData := manager.GetAttestationData() - assert.Equal(t, testAttestationData, attestationData) - - // Test Start/Stop functionality - manager.Start() - time.Sleep(50 * time.Millisecond) // Let it start - manager.Stop() - - // Context should be done - select { - case <-manager.ctx.Done(): - // Expected - case <-time.After(100 * time.Millisecond): - t.Error("Context should be canceled after Stop()") - } -} - -func TestEvidenceItem_JSONSerialization(t *testing.T) { - evidence := EvidenceItem{ - Evidence: "test-evidence-data", - Certificate: "test-certificate-data", - } - - // Test marshaling - data, err := json.Marshal(evidence) - assert.NoError(t, err) - assert.Contains(t, string(data), "test-evidence-data") - assert.Contains(t, string(data), "test-certificate-data") - - // Test unmarshaling - var unmarshaled EvidenceItem - err = json.Unmarshal(data, &unmarshaled) - assert.NoError(t, err) - assert.Equal(t, evidence.Evidence, unmarshaled.Evidence) - assert.Equal(t, evidence.Certificate, unmarshaled.Certificate) -} - -func TestManager_CalculateJitter(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: true, - } - manager := NewManager(ctx, nil, cfg) - - // Test with 0 max jitter - jitter := manager.calculateJitter(0) - assert.Equal(t, time.Duration(0), jitter) - - // Test with positive max jitter - maxJitter := 100 * time.Millisecond - for i := 0; i < 10; i++ { - jitter = manager.calculateJitter(maxJitter) - assert.GreaterOrEqual(t, jitter, time.Duration(0)) - assert.Less(t, jitter, maxJitter) - } -} - -func TestValidateNonce(t *testing.T) { - tests := []struct { - name string - nonce string - wantErr string - }{ - {name: "valid_hex", nonce: "abcdef0123456789"}, - {name: "valid_base64", nonce: "dGVzdA=="}, - {name: "valid_base64url", nonce: "abc-def_ghi+jkl/mno="}, - {name: 
"empty", nonce: "", wantErr: "nonce is empty"}, - {name: "too_long", nonce: strings.Repeat("a", 513), wantErr: "exceeds maximum"}, - {name: "max_length_ok", nonce: strings.Repeat("a", 512)}, - {name: "space", nonce: "abc def", wantErr: "invalid character"}, - {name: "newline", nonce: "abc\ndef", wantErr: "invalid character"}, - {name: "semicolon", nonce: "abc;def", wantErr: "invalid character"}, - {name: "shell_metachar", nonce: "$(whoami)", wantErr: "invalid character"}, - {name: "flag_like_valid_chars", nonce: "--output=/etc/passwd"}, // all chars are in the base64url allowlist; safe because nvattest receives it as --nonce value, not a flag - {name: "pipe", nonce: "abc|def", wantErr: "invalid character"}, - {name: "null_byte", nonce: "abc\x00def", wantErr: "invalid character"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - err := validateNonce(tc.nonce) - if tc.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - } else { - assert.NoError(t, err) - } - }) - } -} - -func TestManager_RunAttestation_ReturnsRetrySoon(t *testing.T) { - useMissingStateFile(t) - - // This test verifies that runAttestation returns the correct retry hint - // When agent is not enrolled, it should return true (retry soon) - // When there's a real failure, it should return false (normal interval) - - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Run attestation - in test environment, this will fail at getMachineId - // (which is a real failure, not "not enrolled"), so it should return false - shouldRetrySoon := manager.runAttestation() - - // Since we don't have the metadata database with machine ID in test environment, - // attestation fails with a real error (not "not enrolled"), so shouldRetrySoon should be false - assert.False(t, shouldRetrySoon, "Should return false for real 
failures (not 'not enrolled')") - - // Verify the cache has the failure info - attestationData := manager.GetAttestationData() - require.NotNil(t, attestationData) - assert.False(t, attestationData.Success) - assert.NotEmpty(t, attestationData.ErrorMessage) - - // The error should NOT be about enrollment - assert.NotContains(t, attestationData.ErrorMessage, "not enrolled", - "Error should be about machine ID, not enrollment") -} - -func TestRetryInterval_Constant(t *testing.T) { - // Verify the retry interval constant is set appropriately - assert.Equal(t, 5*time.Minute, retryInterval, - "Retry interval should be 5 minutes for quick recovery after enrollment") -} - -func TestManager_GetNextInterval(t *testing.T) { - tests := []struct { - name string - configInterval time.Duration - shouldRetrySoon bool - expectedInterval time.Duration - }{ - { - name: "normal interval when not retrying", - configInterval: 24 * time.Hour, - shouldRetrySoon: false, - expectedInterval: 24 * time.Hour, - }, - { - name: "retry interval when config is longer", - configInterval: 24 * time.Hour, - shouldRetrySoon: true, - expectedInterval: 5 * time.Minute, // retryInterval - }, - { - name: "config interval when config is shorter than retry", - configInterval: 20 * time.Second, - shouldRetrySoon: true, - expectedInterval: 20 * time.Second, // use config, not retryInterval - }, - { - name: "config interval when config equals retry", - configInterval: 5 * time.Minute, - shouldRetrySoon: true, - expectedInterval: 5 * time.Minute, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: tt.configInterval}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - interval := manager.getNextInterval(tt.shouldRetrySoon) - assert.Equal(t, tt.expectedInterval, interval) - }) - } -} - -func TestManager_RunAttestationWithJitter_Disabled(t *testing.T) { - ctx := 
context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := NewManager(ctx, nil, cfg) - - // Should run immediately (we can't easily verify it ran without mocking runAttestation, - // but we can ensure it doesn't panic and covers the code path) - done := make(chan bool) - go func() { - manager.runAttestationWithJitter() - done <- true - }() - - select { - case <-done: - // Success - case <-time.After(15 * time.Second): - t.Error("runAttestationWithJitter should return immediately when jitter is disabled") - } -} - -func TestManager_RunAttestationWithJitter_Enabled(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: true, - } - manager := NewManager(ctx, nil, cfg) - - // We can't easily wait for the random jitter, but we can verify it respects context cancellation - done := make(chan bool) - go func() { - manager.runAttestationWithJitter() - done <- true - }() - - // Cancel context to force exit - cancel() - - select { - case <-done: - // Success - case <-time.After(15 * time.Second): - t.Error("runAttestationWithJitter should return when context is canceled") - } -} diff --git a/internal/attestation/backend.go b/internal/attestation/backend.go new file mode 100644 index 00000000..d9c52811 --- /dev/null +++ b/internal/attestation/backend.go @@ -0,0 +1,128 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package attestation
+
+import (
+	"context"
+	"fmt"
+	"time"
+
+	"github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate"
+	"github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient"
+)
+
+// newBackendClient constructs a backend client for a base URL. It is a
+// package variable so tests can replace it with a stub.
+var newBackendClient = backendclient.New
+
+// toAttestationRequest converts an attestation Result into the request body
+// sent to the backend. A nil Result yields a nil request. Evidence items are
+// copied field by field; when the Result carries no evidences the request's
+// Evidences slice is left nil.
+func toAttestationRequest(r *Result) *backendclient.AttestationRequest {
+	if r == nil {
+		return nil
+	}
+	req := &backendclient.AttestationRequest{
+		AttestationData: backendclient.AttestationData{
+			NonceRefreshTimestamp: r.NonceRefreshTimestamp,
+			Success:               r.Success,
+			ErrorMessage:          r.ErrorMessage,
+			SDKResponse: backendclient.AttestationSDKResponse{
+				ResultCode:    r.SDKResponse.ResultCode,
+				ResultMessage: r.SDKResponse.ResultMessage,
+			},
+		},
+	}
+
+	if len(r.SDKResponse.Evidences) > 0 {
+		req.AttestationData.SDKResponse.Evidences = make([]backendclient.EvidenceItem, 0, len(r.SDKResponse.Evidences))
+		for _, ev := range r.SDKResponse.Evidences {
+			req.AttestationData.SDKResponse.Evidences = append(req.AttestationData.SDKResponse.Evidences, backendclient.EvidenceItem{
+				Arch:          ev.Arch,
+				Certificate:   ev.Certificate,
+				DriverVersion: ev.DriverVersion,
+				Evidence:      ev.Evidence,
+				Nonce:         ev.Nonce,
+				VBIOSVersion:  ev.VBIOSVersion,
+				Version:       ev.Version,
+			})
+		}
+	}
+
+	return req
+}
+
+// stateBackendClientFactory builds backend clients from the base URL held in
+// agent state, reading the state on every call.
+type stateBackendClientFactory struct {
+	state agentstate.State
+}
+
+// client returns a backend client for the base URL currently stored in agent
+// state.
+//
+// It fails when the factory or its state is nil, when reading the base URL
+// fails, or when no base URL is stored; the last case wraps ErrNotEnrolled.
+func (f *stateBackendClientFactory) client(ctx context.Context) (backendclient.Client, error) {
+	// Tolerate a nil receiver as well as nil state so callers holding a
+	// zero-value provider get an error instead of a panic.
+	if f == nil || f.state == nil {
+		return nil, fmt.Errorf("backend client factory requires agent state")
+	}
+	baseURL, ok, err := f.state.GetBackendBaseURL(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if !ok || baseURL == "" {
+		return nil, fmt.Errorf("%w: backend base URL not available in agent state", ErrNotEnrolled)
+	}
+	return newBackendClient(baseURL)
+}
+
+// stateNonceProvider fetches nonces through a backend client that is resolved
+// from agent state on each call.
+type stateNonceProvider struct {
+	factory *stateBackendClientFactory
+}
+
+// NewStateNonceProvider creates a nonce provider that resolves backend state dynamically.
+func NewStateNonceProvider(state agentstate.State) NonceProvider {
+	return &stateNonceProvider{factory: &stateBackendClientFactory{state: state}}
+}
+
+// GetNonce builds a backend client from agent state and delegates to the
+// backend nonce provider. On a client construction failure it returns zero
+// values together with that error.
+func (p *stateNonceProvider) GetNonce(ctx context.Context, nodeUUID, jwt string) (string, time.Time, string, error) {
+	client, err := p.factory.client(ctx)
+	if err != nil {
+		return "", time.Time{}, "", err
+	}
+	return NewBackendNonceProvider(client).GetNonce(ctx, nodeUUID, jwt)
+}
+
+// stateSubmitter submits attestation results through a backend client that is
+// resolved from agent state on each call.
+type stateSubmitter struct {
+	factory *stateBackendClientFactory
+}
+
+// NewStateBackendSubmitter creates a submitter that resolves backend state dynamically.
+func NewStateBackendSubmitter(state agentstate.State) Submitter {
+	return &stateSubmitter{factory: &stateBackendClientFactory{state: state}}
+}
+
+// Submit sends result to the backend using jwt.
+//
+// When result has no NodeUUID, the UUID is read from agent state and set on a
+// shallow copy, so the caller's Result is not modified. It returns an error
+// when result is nil, when agent state is missing, when the node UUID or the
+// backend base URL is unavailable (both wrap ErrNotEnrolled), or when the
+// underlying submission fails.
+func (s *stateSubmitter) Submit(ctx context.Context, result *Result, jwt string) error {
+	if result == nil {
+		return fmt.Errorf("attestation submission requires result")
+	}
+	// Check state before the node UUID lookup below: calling a method on a nil
+	// state would panic. The message matches the one client() returns, so both
+	// paths report missing state identically.
+	if s.factory == nil || s.factory.state == nil {
+		return fmt.Errorf("backend client factory requires agent state")
+	}
+	if result.NodeUUID == "" {
+		nodeUUID, ok, err := s.factory.state.GetNodeUUID(ctx)
+		if err != nil {
+			return err
+		}
+		if !ok || nodeUUID == "" {
+			return fmt.Errorf("%w: node UUID not available in agent state", ErrNotEnrolled)
+		}
+		cloned := *result
+		cloned.NodeUUID = nodeUUID
+		result = &cloned
+	}
+	client, err := s.factory.client(ctx)
+	if err != nil {
+		return err
+	}
+	return NewBackendSubmitter(client).Submit(ctx, result, jwt)
+}
diff --git a/internal/attestation/backend_test.go b/internal/attestation/backend_test.go
new file mode 100644
index 00000000..ed43ae4d
--- /dev/null
+++ b/internal/attestation/backend_test.go
@@ -0,0 +1,203 @@
+// SPDX-FileCopyrightText: Copyright (c) 
2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package attestation
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+
+	"github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient"
+)
+
+// stubState is an in-memory agent state whose getters return canned values.
+// SetJWT records the value it was given; the other setters do nothing.
+type stubState struct {
+	baseURL  string
+	baseOK   bool
+	baseErr  error
+	jwt      string
+	jwtOK    bool
+	jwtErr   error
+	setJWT   string
+	nodeUUID string
+	nodeOK   bool
+	nodeErr  error
+}
+
+// GetBackendBaseURL returns the canned base URL, presence flag and error.
+func (s *stubState) GetBackendBaseURL(context.Context) (string, bool, error) {
+	return s.baseURL, s.baseOK, s.baseErr
+}
+
+// SetBackendBaseURL is a no-op.
+func (s *stubState) SetBackendBaseURL(context.Context, string) error {
+	return nil
+}
+
+// GetJWT returns the canned JWT, presence flag and error.
+func (s *stubState) GetJWT(context.Context) (string, bool, error) {
+	return s.jwt, s.jwtOK, s.jwtErr
+}
+
+// SetJWT stores the token both as the current JWT and as the last value set.
+func (s *stubState) SetJWT(_ context.Context, token string) error {
+	s.setJWT, s.jwt = token, token
+	return nil
+}
+
+// GetSAK always reports that no SAK is stored.
+func (s *stubState) GetSAK(context.Context) (string, bool, error) {
+	return "", false, nil
+}
+
+// SetSAK is a no-op.
+func (s *stubState) SetSAK(context.Context, string) error {
+	return nil
+}
+
+// GetNodeUUID returns the canned node UUID, presence flag and error.
+func (s *stubState) GetNodeUUID(context.Context) (string, bool, error) {
+	return s.nodeUUID, s.nodeOK, s.nodeErr
+}
+
+// SetNodeUUID is a no-op.
+func (s *stubState) SetNodeUUID(context.Context, string) error {
+	return nil
+}
+
+// recordingClient is a backend client stub that remembers the arguments of the
+// last SubmitAttestation call and serves a canned nonce response.
+type recordingClient struct {
+	lastNodeUUID string
+	lastJWT      string
+	lastReq      *backendclient.AttestationRequest
+	nonceResp    *backendclient.NonceResponse
+}
+
+// Enroll returns an empty result.
+func (c *recordingClient) Enroll(context.Context, string) (string, error) {
+	return "", nil
+}
+
+// UpsertNode is a no-op.
+func (c *recordingClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error {
+	return nil
+}
+
+// GetNonce returns the canned nonce response.
+func (c *recordingClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) {
+	return c.nonceResp, nil
+}
+
+// SubmitAttestation records its arguments and succeeds.
+func (c *recordingClient) SubmitAttestation(_ context.Context, nodeUUID string, req *backendclient.AttestationRequest, jwt string) error {
+	c.lastNodeUUID, c.lastJWT, c.lastReq = nodeUUID, jwt, req
+	return nil
+}
+
+// RefreshToken returns an empty result.
+func (c *recordingClient) RefreshToken(context.Context, string) (string, error) {
+	return "", nil
+}
+
+// TestToAttestationRequest checks the field mapping and the nil-input case.
+func TestToAttestationRequest(t *testing.T) {
+	stamp := time.Now().UTC()
+	item := EvidenceItem{
+		Arch:          "BLACKWELL",
+		Certificate:   "cert",
+		DriverVersion: "575.1",
+		Evidence:      "blob",
+		Nonce:         "nonce",
+		VBIOSVersion:  "vbios",
+		Version:       "1.0",
+	}
+	input := &Result{
+		NonceRefreshTimestamp: stamp,
+		Success:               true,
+		ErrorMessage:          "",
+		SDKResponse: SDKResponse{
+			ResultCode:    7,
+			ResultMessage: "ok",
+			Evidences:     []EvidenceItem{item},
+		},
+	}
+
+	converted := toAttestationRequest(input)
+	require.NotNil(t, converted)
+
+	data := converted.AttestationData
+	require.Equal(t, stamp, data.NonceRefreshTimestamp)
+	require.True(t, data.Success)
+	require.Len(t, data.SDKResponse.Evidences, 1)
+	require.Equal(t, "BLACKWELL", data.SDKResponse.Evidences[0].Arch)
+
+	require.Nil(t, toAttestationRequest(nil))
+}
+
+// TestStateBackendClientFactory covers the success path and each failure mode
+// of stateBackendClientFactory.client.
+func TestStateBackendClientFactory(t *testing.T) {
+	saved := newBackendClient
+	t.Cleanup(func() { newBackendClient = saved })
+
+	stub := &recordingClient{}
+	newBackendClient = func(rawBaseURL string) (backendclient.Client, error) {
+		require.Equal(t, "https://backend.example.com", rawBaseURL)
+		return stub, nil
+	}
+
+	ctx := context.Background()
+
+	enrolled := &stateBackendClientFactory{state: &stubState{baseURL: "https://backend.example.com", baseOK: true}}
+	built, err := enrolled.client(ctx)
+	require.NoError(t, err)
+	require.Equal(t, stub, built)
+
+	failures := []struct {
+		factory *stateBackendClientFactory
+		want    string
+	}{
+		{factory: &stateBackendClientFactory{}, want: "requires agent state"},
+		{factory: &stateBackendClientFactory{state: &stubState{baseErr: errors.New("boom")}}, want: "boom"},
+		{factory: &stateBackendClientFactory{state: &stubState{baseOK: false}}, want: "backend base URL not available"},
+	}
+	for _, tc := range failures {
+		_, err = tc.factory.client(ctx)
+		require.ErrorContains(t, err, tc.want)
+	}
+}
+
+// TestStateProvidersAndSubmitter exercises the state-backed JWT, node UUID and
+// nonce providers and the submitter against one shared stub state and client.
+func TestStateProvidersAndSubmitter(t *testing.T) {
+	saved := newBackendClient
+	t.Cleanup(func() { newBackendClient = saved })
+
+	recording := &recordingClient{
+		nonceResp: &backendclient.NonceResponse{
+			Nonce:                 "abc123",
+			NonceRefreshTimestamp: time.Unix(10, 0).UTC(),
+			JWTAssertion:          "new-jwt",
+		},
+	}
+	newBackendClient = func(rawBaseURL string) (backendclient.Client, error) {
+		require.Equal(t, "https://backend.example.com", rawBaseURL)
+		return recording, nil
+	}
+
+	ctx := context.Background()
+	state := &stubState{
+		baseURL:  "https://backend.example.com",
+		baseOK:   true,
+		jwt:      "jwt-token",
+		jwtOK:    true,
+		nodeUUID: "node-1",
+		nodeOK:   true,
+	}
+
+	// JWT provider: reads the stored token, and writes go through to state.
+	tokens := NewStateJWTProvider(state)
+	token, err := tokens.GetJWT(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "jwt-token", token)
+	require.NoError(t, tokens.SetJWT(ctx, "updated"))
+	require.Equal(t, "updated", state.setJWT)
+
+	// Node UUID provider.
+	lookupUUID := NewStateNodeUUIDProvider(state)
+	gotUUID, err := lookupUUID(ctx)
+	require.NoError(t, err)
+	require.Equal(t, "node-1", gotUUID)
+
+	// Nonce provider: returns the canned backend response.
+	nonces := NewStateNonceProvider(state)
+	nonce, refreshedAt, refreshedJWT, err := nonces.GetNonce(ctx, "node-1", "jwt-token")
+	require.NoError(t, err)
+	require.Equal(t, "abc123", nonce)
+	require.Equal(t, time.Unix(10, 0).UTC(), refreshedAt)
+	require.Equal(t, "new-jwt", refreshedJWT)
+
+	// Submitter with an explicit node UUID on the result.
+	withUUID := &Result{
+		NodeUUID: "node-1",
+		SDKResponse: SDKResponse{
+			ResultCode: 1,
+			Evidences:  []EvidenceItem{{Arch: "BLACKWELL"}},
+		},
+	}
+	err = NewStateBackendSubmitter(state).Submit(ctx, withUUID, "jwt-token")
+	require.NoError(t, err)
+	require.Equal(t, "node-1", recording.lastNodeUUID)
+	require.Equal(t, "jwt-token", recording.lastJWT)
+	require.NotNil(t, recording.lastReq)
+	require.Equal(t, "BLACKWELL", recording.lastReq.AttestationData.SDKResponse.Evidences[0].Arch)
+
+	// Submitter with an empty node UUID: the UUID comes from state.
+	recording.lastNodeUUID = ""
+	err = NewStateBackendSubmitter(state).Submit(ctx, &Result{}, "jwt-token")
+	require.NoError(t, err)
+	require.Equal(t, "node-1", recording.lastNodeUUID)
+}
+
+// TestStateProvidersPropagateBackendClientConstructionErrors checks that a
+// failing client constructor surfaces from both GetNonce and Submit.
+func TestStateProvidersPropagateBackendClientConstructionErrors(t *testing.T) {
+	saved := newBackendClient
+	t.Cleanup(func() { newBackendClient = saved })
+
+	newBackendClient = func(string) (backendclient.Client, error) {
+		return nil, errors.New("construct failed")
+	}
+
+	ctx := context.Background()
+	state := &stubState{baseURL: "https://backend.example.com", baseOK: true}
+
+	_, _, _, nonceErr := NewStateNonceProvider(state).GetNonce(ctx, "node-1", "jwt-token")
+	require.ErrorContains(t, nonceErr, "construct failed")
+
+	submitErr := NewStateBackendSubmitter(state).Submit(ctx, &Result{NodeUUID: "node-1"}, "jwt-token")
+	require.ErrorContains(t, submitErr, "construct failed")
+}
diff --git a/internal/attestation/collector.go b/internal/attestation/collector.go
new file mode 100644
index 00000000..7f747bbf
--- /dev/null
+++ b/internal/attestation/collector.go
@@ -0,0 +1,101 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package attestation
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"os/exec"
+	"time"
+)
+
+// execCommandContext builds the command run by the collector. It is a package
+// variable so tests can swap in a helper process.
+var execCommandContext = exec.CommandContext
+
+// cliEvidenceCollector gathers attestation evidence by invoking the nvattest
+// CLI, bounding each invocation by timeout.
+type cliEvidenceCollector struct {
+	timeout time.Duration
+}
+
+// NewCLIEvidenceCollector creates an evidence collector backed by the nvattest CLI.
+// A non-positive timeout is replaced by 60 seconds.
+func NewCLIEvidenceCollector(timeout time.Duration) EvidenceCollector {
+	collector := &cliEvidenceCollector{timeout: 60 * time.Second}
+	if timeout > 0 {
+		collector.timeout = timeout
+	}
+	return collector
+}
+
+// Collect validates nonce, runs "nvattest collect-evidence" with it under the
+// collector's timeout, and decodes the JSON the CLI writes to stdout.
+//
+// A failed run is only fatal when the CLI produced no stdout; otherwise stdout
+// is still parsed and returned. Errors are returned for an invalid nonce, for
+// a failed run without output, and for output that is not valid JSON.
+func (c *cliEvidenceCollector) Collect(ctx context.Context, nonce string) (*SDKResponse, error) {
+	if nonceErr := validateNonce(nonce); nonceErr != nil {
+		return nil, fmt.Errorf("invalid nonce received from backend: %w", nonceErr)
+	}
+
+	boundedCtx, stop := context.WithTimeout(ctx, c.timeout)
+	defer stop()
+
+	args := []string{
+		"collect-evidence",
+		"--gpu-evidence-source=corelib",
+		"--nonce", nonce,
+		"--gpu-architecture", "blackwell",
+		"--format", "json",
+	}
+	command := execCommandContext(boundedCtx, "nvattest", args...)
+
+	outBuf := &bytes.Buffer{}
+	errBuf := &bytes.Buffer{}
+	command.Stdout = outBuf
+	command.Stderr = errBuf
+
+	runErr := command.Run()
+	if runErr != nil && outBuf.Len() == 0 {
+		return nil, fmt.Errorf("attestation CLI execution failed: %w (stderr: %s)", runErr, errBuf.String())
+	}
+
+	parsed := new(SDKResponse)
+	decodeErr := json.Unmarshal(outBuf.Bytes(), parsed)
+	if decodeErr == nil {
+		return parsed, nil
+	}
+
+	// Include the run error text, if any, alongside the parse failure.
+	var runErrText string
+	if runErr != nil {
+		runErrText = runErr.Error()
+	}
+	return nil, fmt.Errorf(
+		"failed to parse CLI response: %w (stderr: %s), stdout: %s, error: %s",
+		decodeErr, errBuf.String(), outBuf.String(), runErrText,
+	)
+}
+
+// validateNonce rejects nonces that are empty, longer than 512 bytes, or that
+// contain anything other than ASCII letters, digits and '-', '_', '=', '+', '/'.
+// The reported position is the byte offset of the offending character.
+func validateNonce(nonce string) error {
+	const maxLen = 512
+
+	size := len(nonce)
+	if size == 0 {
+		return fmt.Errorf("nonce is empty")
+	}
+	if size > maxLen {
+		return fmt.Errorf("nonce length %d exceeds maximum of %d characters", size, maxLen)
+	}
+
+	for offset, r := range nonce {
+		isDigit := '0' <= r && r <= '9'
+		isLower := 'a' <= r && r <= 'z'
+		isUpper := 'A' <= r && r <= 'Z'
+		if isDigit || isLower || isUpper {
+			continue
+		}
+		if r == '-' || r == '_' || r == '=' || r == '+' || r == '/' {
+			continue
+		}
+		return fmt.Errorf("nonce contains invalid character %q at position %d", r, offset)
+	}
+	return nil
+}
diff --git a/internal/attestation/collector_test.go b/internal/attestation/collector_test.go
new file mode 100644
index 00000000..93423924
--- /dev/null
+++ b/internal/attestation/collector_test.go
@@ -0,0 +1,129 @@
+// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +package attestation + +import ( + "context" + "errors" + "os" + "os/exec" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +func TestValidateNonce(t *testing.T) { + require.NoError(t, validateNonce("abc123-_=/+")) + require.Error(t, validateNonce("")) + require.Error(t, validateNonce("bad nonce")) +} + +func TestCLIEvidenceCollectorRejectsInvalidNonce(t *testing.T) { + collector := NewCLIEvidenceCollector(time.Second) + _, err := collector.Collect(context.Background(), "bad nonce") + require.Error(t, err) +} + +func TestCLIEvidenceCollectorParsesResponse(t *testing.T) { + original := execCommandContext + t.Cleanup(func() { execCommandContext = original }) + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcess$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") + return cmd + } + + collector := NewCLIEvidenceCollector(time.Second) + resp, err := collector.Collect(context.Background(), "abc123") + require.NoError(t, err) + require.Equal(t, 200, resp.ResultCode) + require.Equal(t, "ok", resp.ResultMessage) +} + +func TestCLIEvidenceCollectorExecutionAndParseErrors(t *testing.T) { + original := execCommandContext + t.Cleanup(func() { execCommandContext = original }) + + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcessError$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) 
+ cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1", "HELPER_MODE=stderr_only") + return cmd + } + collector := NewCLIEvidenceCollector(time.Second) + _, err := collector.Collect(context.Background(), "abc123") + require.Error(t, err) + + execCommandContext = func(ctx context.Context, name string, args ...string) *exec.Cmd { + argv := append([]string{"-test.run=^TestHelperProcessError$", "--", name}, args...) + cmd := exec.CommandContext(ctx, os.Args[0], argv...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1", "HELPER_MODE=bad_json") + return cmd + } + _, err = collector.Collect(context.Background(), "abc123") + require.Error(t, err) +} + +func TestBackendNonceProviderErrors(t *testing.T) { + _, _, _, err := NewBackendNonceProvider(nil).GetNonce(context.Background(), "node", "jwt") + require.ErrorContains(t, err, "backend client") + + client := &testNonceClient{} + _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "", "jwt") + require.ErrorContains(t, err, "node UUID") + _, _, _, err = NewBackendNonceProvider(client).GetNonce(context.Background(), "node", "") + require.ErrorContains(t, err, "jwt") + + nilClient := &nilNonceClient{} + _, _, _, err = NewBackendNonceProvider(nilClient).GetNonce(context.Background(), "node", "jwt") + require.ErrorContains(t, err, "nil") +} + +type nilNonceClient struct{} + +func (c *nilNonceClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} + +func TestHelperProcessError(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + switch os.Getenv("HELPER_MODE") { + case "stderr_only": + _, _ = os.Stderr.WriteString("boom") + os.Exit(1) + case "bad_json": + _, _ = os.Stdout.WriteString("{") + _, _ = os.Stderr.WriteString("warn") + os.Exit(1) + default: + _ = errors.New("unused") + os.Exit(2) + } +} + +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + _, _ 
= os.Stdout.WriteString(`{"evidences":[],"resultCode":200,"resultMessage":"ok"}`) + os.Exit(0) +} diff --git a/internal/attestation/manager.go b/internal/attestation/manager.go new file mode 100644 index 00000000..000e2092 --- /dev/null +++ b/internal/attestation/manager.go @@ -0,0 +1,323 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestation + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +// JWTProvider retrieves the current backend JWT. +type JWTProvider interface { + GetJWT(ctx context.Context) (string, error) + SetJWT(ctx context.Context, value string) error +} + +// Manager coordinates periodic attestation collection into a store. 
+type Manager interface { + Run(ctx context.Context) error + CollectOnce(ctx context.Context) (*Result, error) + LastResult() *Result + IsResultUpdated(since time.Time) bool +} + +type manager struct { + mu sync.RWMutex + nodeUUIDProvider func(context.Context) (string, error) + jwtProvider JWTProvider + nonceProvider NonceProvider + collector EvidenceCollector + submitter Submitter + config AttestationConfig + + lastResult *Result + lastUpdated time.Time +} + +// NewManager creates an attestation loop manager skeleton. +func NewManager( + nodeUUIDProvider func(context.Context) (string, error), + jwtProvider JWTProvider, + nonceProvider NonceProvider, + collector EvidenceCollector, + submitter Submitter, + cfg AttestationConfig, +) Manager { + return &manager{ + nodeUUIDProvider: nodeUUIDProvider, + jwtProvider: jwtProvider, + nonceProvider: nonceProvider, + collector: collector, + submitter: submitter, + config: cfg, + } +} + +func (m *manager) Run(ctx context.Context) error { + if m.nodeUUIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { + return fmt.Errorf("attestation loop dependencies are incomplete") + } + if m.config.Interval <= 0 { + return nil + } + startupInterval := m.config.Interval + if m.config.InitialInterval > 0 { + startupInterval = m.config.InitialInterval + } + if m.config.JitterEnabled { + jitter := calculateJitter(initialJitterCap(startupInterval)) + log.Logger.Infow("adding initial attestation startup jitter", "jitter_duration", jitter) + if err := sleepWithContext(ctx, jitter); err != nil { + return err + } + } + + firstSuccess := false + for { + _, err := m.CollectOnce(ctx) + nextInterval := m.config.Interval + if err == nil { + firstSuccess = true + } else { + log.Logger.Warnw("attestation workflow failed", "error", err) + switch { + case !firstSuccess && errors.Is(err, ErrNotEnrolled) && m.config.InitialInterval > 0: + nextInterval = m.config.InitialInterval + case 
m.config.RetryInterval > 0 && (nextInterval <= 0 || m.config.RetryInterval < nextInterval): + nextInterval = m.config.RetryInterval + if m.config.JitterEnabled { + nextInterval += calculateJitter(retryJitterCap(m.config.RetryInterval)) + } + } + } + if err := sleepWithContext(ctx, nextInterval); err != nil { + return err + } + } +} + +func (m *manager) CollectOnce(ctx context.Context) (*Result, error) { + if m.nodeUUIDProvider == nil || m.jwtProvider == nil || m.nonceProvider == nil || m.collector == nil || m.submitter == nil { + return nil, fmt.Errorf("attestation loop dependencies are incomplete") + } + + nodeUUID, err := m.nodeUUIDProvider(ctx) + if err != nil { + return nil, err + } + jwt, err := m.jwtProvider.GetJWT(ctx) + if err != nil { + return nil, err + } + nonce, refreshTS, refreshedJWT, err := m.nonceProvider.GetNonce(ctx, nodeUUID, jwt) + if err != nil { + return nil, err + } + if refreshedJWT != "" && refreshedJWT != jwt { + if err := m.jwtProvider.SetJWT(ctx, refreshedJWT); err != nil { + return nil, err + } + jwt = refreshedJWT + } + sdkResp, err := m.collector.Collect(ctx, nonce) + collectErr := err + result := &Result{ + CollectedAt: time.Now().UTC(), + NodeUUID: nodeUUID, + NonceRefreshTimestamp: refreshTS, + } + if err != nil { + result.Success = false + result.ErrorMessage = err.Error() + } else { + result.Success = true + } + if sdkResp != nil { + result.SDKResponse = *sdkResp + } + m.mu.Lock() + cloned := *result + m.lastResult = &cloned + m.lastUpdated = time.Now().UTC() + m.mu.Unlock() + if err := m.submitter.Submit(ctx, result, jwt); err != nil { + return nil, err + } + if collectErr != nil { + return result, collectErr + } + return result, nil +} + +func (m *manager) LastResult() *Result { + m.mu.RLock() + defer m.mu.RUnlock() + if m.lastResult == nil { + return nil + } + cloned := *m.lastResult + return &cloned +} + +func (m *manager) IsResultUpdated(since time.Time) bool { + m.mu.RLock() + defer m.mu.RUnlock() + return 
m.lastUpdated.After(since) +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func calculateJitter(maxJitter time.Duration) time.Duration { + if maxJitter <= 0 { + return 0 + } + maxMs := int64(maxJitter / time.Millisecond) + if maxMs <= 0 { + return 0 + } + randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) + if err != nil { + log.Logger.Warnw("failed to generate secure attestation jitter, using fallback", "error", err) + return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond + } + return time.Duration(randomMs.Int64()) * time.Millisecond +} + +func initialJitterCap(interval time.Duration) time.Duration { + if interval <= 0 { + return 0 + } + jitter := interval / 4 + const maxInitialJitter = 30 * time.Minute + if jitter > maxInitialJitter { + return maxInitialJitter + } + return jitter +} + +func retryJitterCap(retryInterval time.Duration) time.Duration { + if retryInterval <= 0 { + return 0 + } + jitter := retryInterval / 2 + const maxRetryJitter = 5 * time.Minute + if jitter > maxRetryJitter { + return maxRetryJitter + } + return jitter +} + +type backendSubmitter struct { + client BackendClient +} + +// BackendClient is the backend client view required by the attestation workflow. +type BackendClient interface { + SubmitAttestation(ctx context.Context, nodeUUID string, req *backendclient.AttestationRequest, jwt string) error +} + +// NewBackendSubmitter creates a backend submitter backed by the agent backend client. 
+func NewBackendSubmitter(client BackendClient) Submitter { + return &backendSubmitter{client: client} +} + +func (s *backendSubmitter) Submit(ctx context.Context, result *Result, jwt string) error { + if s.client == nil { + return fmt.Errorf("attestation submission requires backend client") + } + if result == nil { + return fmt.Errorf("attestation submission requires result") + } + if jwt == "" { + return fmt.Errorf("attestation submission requires jwt") + } + return s.client.SubmitAttestation(ctx, result.NodeUUID, toAttestationRequest(result), jwt) +} + +type stateJWTProvider struct { + state agentstate.State +} + +// NewStateJWTProvider returns a JWT provider backed by persisted agent state. +func NewStateJWTProvider(state agentstate.State) JWTProvider { + return &stateJWTProvider{state: state} +} + +func (p *stateJWTProvider) GetJWT(ctx context.Context) (string, error) { + if p.state == nil { + return "", fmt.Errorf("jwt provider requires agent state") + } + value, ok, err := p.state.GetJWT(ctx) + if err != nil { + return "", err + } + if !ok || value == "" { + return "", fmt.Errorf("%w: jwt not available in agent state", ErrNotEnrolled) + } + return value, nil +} + +func (p *stateJWTProvider) SetJWT(ctx context.Context, value string) error { + if p.state == nil { + return fmt.Errorf("jwt provider requires agent state") + } + return p.state.SetJWT(ctx, value) +} + +// NewStateNodeUUIDProvider returns a node UUID provider backed by persisted agent state. 
+func NewStateNodeUUIDProvider(state agentstate.State) func(context.Context) (string, error) { + return func(ctx context.Context) (string, error) { + if state == nil { + return "", fmt.Errorf("node UUID provider requires agent state") + } + value, ok, err := state.GetNodeUUID(ctx) + if err != nil { + return "", err + } + if !ok || value == "" { + return "", fmt.Errorf("%w: node UUID not available in agent state", ErrNotEnrolled) + } + return value, nil + } +} diff --git a/internal/attestation/manager_test.go b/internal/attestation/manager_test.go new file mode 100644 index 00000000..fb7b5d12 --- /dev/null +++ b/internal/attestation/manager_test.go @@ -0,0 +1,256 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package attestation + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type testJWTProvider struct { + jwt string + setJWT string +} + +func (p *testJWTProvider) GetJWT(context.Context) (string, error) { return p.jwt, nil } +func (p *testJWTProvider) SetJWT(_ context.Context, value string) error { + p.setJWT = value + p.jwt = value + return nil +} + +type testNonceProvider struct { + nonce string + refreshTS time.Time + refreshedJWT string + err error +} + +func (p *testNonceProvider) GetNonce(context.Context, string, string) (string, time.Time, string, error) { + return p.nonce, p.refreshTS, p.refreshedJWT, p.err +} + +type testEvidenceCollector struct { + resp *SDKResponse + err error + n int +} + +func (c *testEvidenceCollector) Collect(context.Context, string) (*SDKResponse, error) { + c.n++ + return c.resp, c.err +} + +type submitted struct { + result *Result + jwt string +} + +type testSubmitter struct { + submitted submitted + err error + count int +} + +func (s *testSubmitter) Submit(_ context.Context, result *Result, jwt string) error { + s.count++ + s.submitted = submitted{result: result, jwt: jwt} + return s.err +} + +func TestCollectOnceSuccess(t *testing.T) { + refreshTS := time.Now().UTC() + jwtProvider := &testJWTProvider{jwt: "old-jwt"} + submitter := &testSubmitter{} + manager := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + jwtProvider, + &testNonceProvider{nonce: "abc123", refreshTS: refreshTS, refreshedJWT: "new-jwt"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200, ResultMessage: "ok"}}, + submitter, + AttestationConfig{}, + ) + + result, err := manager.CollectOnce(context.Background()) + require.NoError(t, err) + require.True(t, result.Success) + require.Equal(t, "node-1", result.NodeUUID) + require.Equal(t, refreshTS, result.NonceRefreshTimestamp) + require.Equal(t, "new-jwt", jwtProvider.setJWT) + require.NotNil(t, 
submitter.submitted.result) + require.Equal(t, "new-jwt", submitter.submitted.jwt) + require.True(t, submitter.submitted.result.Success) +} + +func TestCollectOnceCollectorFailureStillSubmitsFailureResult(t *testing.T) { + submitter := &testSubmitter{} + manager := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{err: errors.New("collect failed")}, + submitter, + AttestationConfig{}, + ) + + result, err := manager.CollectOnce(context.Background()) + require.ErrorContains(t, err, "collect failed") + require.False(t, result.Success) + require.Equal(t, "collect failed", result.ErrorMessage) + require.NotNil(t, submitter.submitted.result) + require.False(t, submitter.submitted.result.Success) +} + +func TestCollectOnceMissingDependencies(t *testing.T) { + _, err := NewManager(nil, nil, nil, nil, nil, AttestationConfig{}).CollectOnce(context.Background()) + require.Error(t, err) +} + +func TestManagerRunAndCachedResult(t *testing.T) { + submitter := &testSubmitter{} + mgr := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200}}, + submitter, + AttestationConfig{InitialInterval: 5 * time.Millisecond, Interval: 5 * time.Millisecond}, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + require.ErrorIs(t, mgr.Run(ctx), context.DeadlineExceeded) + + last := mgr.LastResult() + require.NotNil(t, last) + require.Equal(t, "node-1", last.NodeUUID) + require.True(t, mgr.IsResultUpdated(time.Time{})) +} + +func TestManagerRunUsesRetryIntervalOnFailure(t *testing.T) { + submitter := &testSubmitter{} + collector := &testEvidenceCollector{err: errors.New("collect failed")} + mgr := NewManager( + func(context.Context) (string, error) { return 
"node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + collector, + submitter, + AttestationConfig{InitialInterval: 5 * time.Millisecond, Interval: time.Hour, RetryInterval: 5 * time.Millisecond}, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + require.ErrorIs(t, mgr.Run(ctx), context.DeadlineExceeded) + + last := mgr.LastResult() + require.NotNil(t, last) + require.False(t, last.Success) + require.Equal(t, "collect failed", last.ErrorMessage) + require.GreaterOrEqual(t, collector.n, 2) + require.GreaterOrEqual(t, submitter.count, 2) +} + +func TestManagerHelpersAndSubmitterErrors(t *testing.T) { + mgr := NewManager( + func(context.Context) (string, error) { return "node-1", nil }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{}}, + &testSubmitter{}, + AttestationConfig{}, + ) + require.Nil(t, mgr.LastResult()) + require.False(t, mgr.IsResultUpdated(time.Now().UTC())) + + err := NewBackendSubmitter(nil).Submit(context.Background(), &Result{}, "jwt") + require.ErrorContains(t, err, "backend client") + err = NewBackendSubmitter(&recordingClient{}).Submit(context.Background(), nil, "jwt") + require.ErrorContains(t, err, "requires result") + err = NewBackendSubmitter(&recordingClient{}).Submit(context.Background(), &Result{}, "") + require.ErrorContains(t, err, "requires jwt") +} + +func TestStateJWTProviderAndNodeUUIDProviderErrors(t *testing.T) { + _, err := NewStateJWTProvider(nil).GetJWT(context.Background()) + require.ErrorContains(t, err, "requires agent state") + err = NewStateJWTProvider(nil).SetJWT(context.Background(), "x") + require.ErrorContains(t, err, "requires agent state") + + _, err = NewStateJWTProvider(&stubState{}).GetJWT(context.Background()) + require.ErrorContains(t, err, "jwt not available") + + _, err = NewStateNodeUUIDProvider(nil)(context.Background()) + 
require.ErrorContains(t, err, "requires agent state") + _, err = NewStateNodeUUIDProvider(&stubState{})(context.Background()) + require.ErrorContains(t, err, "node UUID not available") +} + +func TestSleepWithContext(t *testing.T) { + require.NoError(t, sleepWithContext(context.Background(), 0)) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, 0), context.Canceled) + + ctx, cancel = context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, sleepWithContext(ctx, time.Hour), context.Canceled) +} + +func TestAttestationJitterHelpers(t *testing.T) { + require.Equal(t, time.Duration(0), initialJitterCap(0)) + require.Equal(t, 75*time.Second, initialJitterCap(5*time.Minute)) + require.Equal(t, 30*time.Minute, initialJitterCap(4*time.Hour)) + + require.Equal(t, time.Duration(0), retryJitterCap(0)) + require.Equal(t, 150*time.Second, retryJitterCap(5*time.Minute)) + require.Equal(t, 5*time.Minute, retryJitterCap(20*time.Minute)) + + require.Equal(t, time.Duration(0), calculateJitter(0)) + jitter := calculateJitter(50 * time.Millisecond) + require.GreaterOrEqual(t, jitter, time.Duration(0)) + require.Less(t, jitter, 50*time.Millisecond) +} + +func TestManagerRunUsesInitialIntervalWithoutJitterWhenNotEnrolled(t *testing.T) { + mgr := NewManager( + func(context.Context) (string, error) { return "", ErrNotEnrolled }, + &testJWTProvider{jwt: "jwt-token"}, + &testNonceProvider{nonce: "abc123"}, + &testEvidenceCollector{resp: &SDKResponse{ResultCode: 200}}, + &testSubmitter{}, + AttestationConfig{ + InitialInterval: 5 * time.Millisecond, + Interval: time.Hour, + RetryInterval: time.Minute, + }, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + start := time.Now() + err := mgr.Run(ctx) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, elapsed, 15*time.Millisecond) + require.Less(t, 
elapsed, 100*time.Millisecond) +} diff --git a/internal/attestation/nonce.go b/internal/attestation/nonce.go new file mode 100644 index 00000000..074e7120 --- /dev/null +++ b/internal/attestation/nonce.go @@ -0,0 +1,59 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package attestation + +import ( + "context" + "fmt" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +// NonceBackendClient is the backend client view required by the nonce provider. +type NonceBackendClient interface { + GetNonce(ctx context.Context, nodeUUID string, jwt string) (*backendclient.NonceResponse, error) +} + +type backendNonceProvider struct { + client NonceBackendClient +} + +// NewBackendNonceProvider creates a nonce provider backed by the agent backend client. 
+func NewBackendNonceProvider(client NonceBackendClient) NonceProvider { + return &backendNonceProvider{client: client} +} + +func (p *backendNonceProvider) GetNonce(ctx context.Context, nodeUUID, jwt string) (string, time.Time, string, error) { + if p.client == nil { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires backend client") + } + if nodeUUID == "" { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires node UUID") + } + if jwt == "" { + return "", time.Time{}, "", fmt.Errorf("nonce provider requires jwt") + } + + resp, err := p.client.GetNonce(ctx, nodeUUID, jwt) + if err != nil { + return "", time.Time{}, "", err + } + if resp == nil { + return "", time.Time{}, "", fmt.Errorf("nonce response is nil") + } + return resp.Nonce, resp.NonceRefreshTimestamp, resp.JWTAssertion, nil +} diff --git a/internal/attestation/nonce_test.go b/internal/attestation/nonce_test.go new file mode 100644 index 00000000..2cd9cc09 --- /dev/null +++ b/internal/attestation/nonce_test.go @@ -0,0 +1,58 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package attestation + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" +) + +type testNonceClient struct { + resp *backendclient.NonceResponse + gotNodeUUID string + gotJWT string +} + +func (c *testNonceClient) GetNonce(_ context.Context, nodeUUID, jwt string) (*backendclient.NonceResponse, error) { + c.gotNodeUUID = nodeUUID + c.gotJWT = jwt + return c.resp, nil +} + +func TestBackendNonceProvider(t *testing.T) { + refreshTS := time.Now().UTC() + client := &testNonceClient{ + resp: &backendclient.NonceResponse{ + Nonce: "abc123", + NonceRefreshTimestamp: refreshTS, + JWTAssertion: "new-jwt", + }, + } + provider := NewBackendNonceProvider(client) + + nonce, ts, jwt, err := provider.GetNonce(context.Background(), "node-1", "jwt-token") + require.NoError(t, err) + require.Equal(t, "abc123", nonce) + require.Equal(t, refreshTS, ts) + require.Equal(t, "new-jwt", jwt) + require.Equal(t, "node-1", client.gotNodeUUID) + require.Equal(t, "jwt-token", client.gotJWT) +} diff --git a/internal/attestation/types.go b/internal/attestation/types.go new file mode 100644 index 00000000..eebd3b28 --- /dev/null +++ b/internal/attestation/types.go @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package attestation owns the backend attestation workflow. +package attestation + +import ( + "context" + "errors" + "time" +) + +// ErrNotEnrolled indicates attestation cannot run yet because the agent is not enrolled. +var ErrNotEnrolled = errors.New("agent not enrolled") + +// Result is the agent-owned attestation state model for the new backend sync loop. +type Result struct { + CollectedAt time.Time + NodeUUID string + NonceRefreshTimestamp time.Time + Success bool + ErrorMessage string + SDKResponse SDKResponse +} + +type SDKResponse struct { + Evidences []EvidenceItem + ResultCode int + ResultMessage string +} + +type EvidenceItem struct { + Arch string + Certificate string + DriverVersion string + Evidence string + Nonce string + VBIOSVersion string + Version string +} + +// NonceProvider retrieves a backend nonce for a node. +type NonceProvider interface { + GetNonce(ctx context.Context, nodeUUID, jwt string) (nonce string, refreshTS time.Time, refreshedJWT string, err error) +} + +// EvidenceCollector collects attestation evidence from local tooling. +type EvidenceCollector interface { + Collect(ctx context.Context, nonce string) (*SDKResponse, error) +} + +// Submitter submits attestation results to the backend. +type Submitter interface { + Submit(ctx context.Context, result *Result, jwt string) error +} + +// AttestationConfig controls periodic attestation workflow scheduling. +type AttestationConfig struct { + InitialInterval time.Duration + Interval time.Duration + RetryInterval time.Duration + JitterEnabled bool +} diff --git a/internal/backendclient/client.go b/internal/backendclient/client.go new file mode 100644 index 00000000..31958e2b --- /dev/null +++ b/internal/backendclient/client.go @@ -0,0 +1,250 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package backendclient provides the agent-facing client for backend workflows. +package backendclient + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" +) + +const ( + userAgent = "fleet-intelligence-agent" + maxResponseBodyBytes = 1 << 20 +) + +var errRedirectNotAllowed = errors.New("backend redirects are not allowed") +var errNilBaseURL = errors.New("backend base URL is required") + +// Client is the backend workflow client used by enrollment, inventory, and attestation paths. +type Client interface { + Enroll(ctx context.Context, sakToken string) (jwt string, err error) + UpsertNode(ctx context.Context, nodeUUID string, req *NodeUpsertRequest, jwt string) error + GetNonce(ctx context.Context, nodeUUID string, jwt string) (*NonceResponse, error) + SubmitAttestation(ctx context.Context, nodeUUID string, req *AttestationRequest, jwt string) error + RefreshToken(ctx context.Context, jwt string) (newJWT string, err error) +} + +type client struct { + httpClient *http.Client + baseURL *url.URL +} + +// New creates a backend client from a validated backend base URL. 
+func New(rawBaseURL string) (Client, error) { + baseURL, err := endpoint.ValidateBackendEndpoint(rawBaseURL) + if err != nil { + return nil, fmt.Errorf("invalid backend base URL: %w", err) + } + + return NewWithHTTPClient(baseURL, &http.Client{Timeout: 30 * time.Second}), nil +} + +// NewWithHTTPClient creates a backend client with an explicit HTTP client. +func NewWithHTTPClient(baseURL *url.URL, httpClient *http.Client) Client { + if httpClient == nil { + httpClient = &http.Client{Timeout: 30 * time.Second} + } + if httpClient.CheckRedirect == nil { + httpClient.CheckRedirect = func(*http.Request, []*http.Request) error { + return errRedirectNotAllowed + } + } + return &client{ + httpClient: httpClient, + baseURL: baseURL, + } +} + +func (c *client) Enroll(ctx context.Context, sakToken string) (string, error) { + if sakToken == "" { + return "", fmt.Errorf("sakToken cannot be empty") + } + + var resp struct { + JWTAssertion string `json:"jwtAssertion"` + } + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "enroll"}, sakToken, nil, &resp); err != nil { + return "", mapEnrollError(err) + } + if resp.JWTAssertion == "" { + return "", fmt.Errorf("enrollment response missing jwtAssertion field") + } + return resp.JWTAssertion, nil +} + +func (c *client) UpsertNode(ctx context.Context, nodeUUID string, req *NodeUpsertRequest, jwt string) error { + if nodeUUID == "" { + return fmt.Errorf("nodeUUID cannot be empty") + } + if jwt == "" { + return fmt.Errorf("jwt cannot be empty") + } + if req == nil { + return fmt.Errorf("node upsert request cannot be nil") + } + return c.doJSON(ctx, http.MethodPut, []string{"v1", "agent", "nodes", nodeUUID}, jwt, req, nil) +} + +func (c *client) GetNonce(ctx context.Context, nodeUUID string, jwt string) (*NonceResponse, error) { + if nodeUUID == "" { + return nil, fmt.Errorf("nodeUUID cannot be empty") + } + if jwt == "" { + return nil, fmt.Errorf("jwt cannot be empty") + } + + var resp NonceResponse + if err := 
c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeUUID, "nonce"}, jwt, nil, &resp); err != nil { + return nil, err + } + if resp.Nonce == "" { + return nil, fmt.Errorf("nonce response missing nonce field") + } + return &resp, nil +} + +func (c *client) SubmitAttestation(ctx context.Context, nodeUUID string, req *AttestationRequest, jwt string) error { + if nodeUUID == "" { + return fmt.Errorf("nodeUUID cannot be empty") + } + if jwt == "" { + return fmt.Errorf("jwt cannot be empty") + } + if req == nil { + return fmt.Errorf("attestation request cannot be nil") + } + return c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "nodes", nodeUUID, "attestation"}, jwt, req, nil) +} + +func (c *client) RefreshToken(ctx context.Context, jwt string) (string, error) { + if jwt == "" { + return "", fmt.Errorf("jwt cannot be empty") + } + + var resp struct { + JWTAssertion string `json:"jwtAssertion"` + } + req := struct { + JWTAssertion string `json:"jwtAssertion"` + }{ + JWTAssertion: jwt, + } + if err := c.doJSON(ctx, http.MethodPost, []string{"v1", "agent", "token"}, "", req, &resp); err != nil { + return "", err + } + if resp.JWTAssertion == "" { + return "", fmt.Errorf("token refresh response missing jwtAssertion field") + } + return resp.JWTAssertion, nil +} + +func (c *client) doJSON(ctx context.Context, method string, pathElems []string, bearerToken string, reqBody any, respBody any) error { + if c.baseURL == nil { + return errNilBaseURL + } + requestURL, err := endpoint.JoinPath(c.baseURL, pathElems...) 
+ if err != nil { + return fmt.Errorf("failed to construct request URL: %w", err) + } + + var bodyReader io.Reader + if reqBody != nil { + payload, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + bodyReader = bytes.NewReader(payload) + } + + req, err := http.NewRequestWithContext(ctx, method, requestURL, bodyReader) + if err != nil { + return fmt.Errorf("failed to create HTTP request: %w", err) + } + req.Header.Set("User-Agent", userAgent) + if reqBody != nil { + req.Header.Set("Content-Type", "application/json") + } + if bearerToken != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", bearerToken)) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to make backend request: %w", err) + } + defer resp.Body.Close() + + data, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodyBytes+1)) + if err != nil { + return fmt.Errorf("failed to read backend response: %w", err) + } + if len(data) > maxResponseBodyBytes { + return fmt.Errorf("backend response too large (max %d bytes)", maxResponseBodyBytes) + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return &HTTPStatusError{ + StatusCode: resp.StatusCode, + Body: string(bytes.TrimSpace(data)), + } + } + + if respBody == nil || len(bytes.TrimSpace(data)) == 0 { + return nil + } + if err := json.Unmarshal(data, respBody); err != nil { + return fmt.Errorf("failed to parse backend response: %w", err) + } + return nil +} + +func mapEnrollError(err error) error { + var statusErr *HTTPStatusError + if !errors.As(err, &statusErr) { + return err + } + + switch statusErr.StatusCode { + case http.StatusBadRequest: + return fmt.Errorf("the token used in the enrollment is not in the correct format") + case http.StatusUnauthorized: + return fmt.Errorf("the token used in the enrollment is incorrect") + case http.StatusForbidden: + return fmt.Errorf("the token used in the enrollment is 
incorrect/expired") + case http.StatusNotFound: + return fmt.Errorf("the endpoint is not found") + case http.StatusTooManyRequests: + return fmt.Errorf("please retry after some time; server is under heavy load") + case http.StatusBadGateway: + return fmt.Errorf("some temporary issue caused enrollment to fail") + case http.StatusServiceUnavailable: + return fmt.Errorf("service is unavailable currently") + case http.StatusGatewayTimeout: + return fmt.Errorf("service is experiencing load and is slow to respond") + default: + return err + } +} diff --git a/internal/backendclient/client_test.go b/internal/backendclient/client_test.go new file mode 100644 index 00000000..8d8c4670 --- /dev/null +++ b/internal/backendclient/client_test.go @@ -0,0 +1,384 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package backendclient + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestClient_Enroll(t *testing.T) { + t.Parallel() + + var ( + gotMethod string + gotPath string + gotAuth string + gotUA string + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + gotUA = r.Header.Get("User-Agent") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "jwt-token"}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + + jwt, err := c.Enroll(context.Background(), "sak-token") + require.NoError(t, err) + require.Equal(t, "jwt-token", jwt) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/enroll", gotPath) + require.Equal(t, "Bearer sak-token", gotAuth) + require.Equal(t, userAgent, gotUA) +} + +func TestNew(t *testing.T) { + t.Parallel() + + c, err := New("https://backend.example.com") + require.NoError(t, err) + require.NotNil(t, c) +} + +func TestClient_UpsertNode(t *testing.T) { + t.Parallel() + + var ( + gotMethod string + gotPath string + gotAuth string + gotReq NodeUpsertRequest + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + _ = json.NewDecoder(r.Body).Decode(&gotReq) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.NoError(t, err) + require.Equal(t, http.MethodPut, gotMethod) + 
require.Equal(t, "/v1/agent/nodes/node-1", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) + require.Equal(t, "node-1", gotReq.Hostname) +} + +func TestClient_GetNonce(t *testing.T) { + t.Parallel() + + var ( + gotMethod string + gotPath string + gotAuth string + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + _ = json.NewEncoder(w).Encode(NonceResponse{ + Nonce: "abc123", + JWTAssertion: "new-jwt", + }) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + resp, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.NoError(t, err) + require.Equal(t, "abc123", resp.Nonce) + require.Equal(t, "new-jwt", resp.JWTAssertion) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/nodes/node-1/nonce", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) +} + +func TestClient_SubmitAttestation(t *testing.T) { + t.Parallel() + + var ( + gotMethod string + gotPath string + gotAuth string + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "jwt-token") + require.NoError(t, err) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/nodes/node-1/attestation", gotPath) + require.Equal(t, "Bearer jwt-token", gotAuth) +} + +func TestClient_RefreshToken(t *testing.T) { + t.Parallel() + + var ( + gotMethod string + gotPath string + gotReq struct { + JWTAssertion string `json:"jwtAssertion"` + } + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + _ = json.NewDecoder(r.Body).Decode(&gotReq) + + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "new-jwt-token"}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + jwt, err := c.RefreshToken(context.Background(), "jwt-token") + require.NoError(t, err) + require.Equal(t, "new-jwt-token", jwt) + require.Equal(t, http.MethodPost, gotMethod) + require.Equal(t, "/v1/agent/token", gotPath) + require.Equal(t, "jwt-token", gotReq.JWTAssertion) +} + +func TestClient_EnrollMapsHTTPStatus(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "bad token", http.StatusUnauthorized) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.Error(t, err) + require.Contains(t, err.Error(), "incorrect") +} + +func TestClient_ValidationErrors(t *testing.T) { + t.Parallel() + + c := NewWithHTTPClient(mustParseURL(t, "https://backend.example.com"), nil) + + _, err := c.Enroll(context.Background(), "") + require.ErrorContains(t, err, "sakToken cannot be empty") + + err = c.UpsertNode(context.Background(), "", &NodeUpsertRequest{}, "jwt") + require.ErrorContains(t, err, "nodeUUID cannot be empty") + err = c.UpsertNode(context.Background(), "node-1", nil, "jwt") + require.ErrorContains(t, err, "cannot be nil") + err = c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{}, "") + require.ErrorContains(t, err, "jwt cannot be empty") + + _, err = c.GetNonce(context.Background(), "", "jwt") + require.ErrorContains(t, err, "nodeUUID cannot be empty") + _, err = c.GetNonce(context.Background(), "node-1", "") + require.ErrorContains(t, err, "jwt cannot be empty") + + err = c.SubmitAttestation(context.Background(), "", 
&AttestationRequest{}, "jwt") + require.ErrorContains(t, err, "nodeUUID cannot be empty") + err = c.SubmitAttestation(context.Background(), "node-1", nil, "jwt") + require.ErrorContains(t, err, "cannot be nil") + err = c.SubmitAttestation(context.Background(), "node-1", &AttestationRequest{}, "") + require.ErrorContains(t, err, "jwt cannot be empty") + + _, err = c.RefreshToken(context.Background(), "") + require.ErrorContains(t, err, "jwt cannot be empty") + + c = NewWithHTTPClient(nil, nil) + _, err = c.Enroll(context.Background(), "sak-token") + require.ErrorIs(t, err, errNilBaseURL) +} + +func TestClient_RejectsRedirects(t *testing.T) { + t.Parallel() + + redirected := false + target := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + redirected = true + w.WriteHeader(http.StatusOK) + })) + defer target.Close() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL, http.StatusTemporaryRedirect) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.ErrorContains(t, err, errRedirectNotAllowed.Error()) + require.False(t, redirected) +} + +func TestClient_ResponseValidationAndErrors(t *testing.T) { + t.Parallel() + + t.Run("missing jwt assertion in enroll", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.ErrorContains(t, err, "missing jwtAssertion") + }) + + t.Run("missing nonce field", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = 
json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "missing nonce") + }) + + t.Run("missing refresh jwt assertion", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]string{}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.RefreshToken(context.Background(), "jwt-token") + require.ErrorContains(t, err, "missing jwtAssertion") + }) + + t.Run("invalid json response", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("{invalid")) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.NoError(t, err) + + _, err = c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "failed to parse backend response") + }) + + t.Run("http client error", func(t *testing.T) { + c := NewWithHTTPClient(mustParseURL(t, "https://backend.example.com"), &http.Client{ + Transport: roundTripFunc(func(*http.Request) (*http.Response, error) { + return nil, errors.New("network boom") + }), + }) + err := c.UpsertNode(context.Background(), "node-1", &NodeUpsertRequest{Hostname: "node-1"}, "jwt-token") + require.ErrorContains(t, err, "failed to make backend request") + }) + + t.Run("oversized response body", func(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(strings.Repeat("a", maxResponseBodyBytes+10))) + })) + defer server.Close() + + c := 
NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.GetNonce(context.Background(), "node-1", "jwt-token") + require.ErrorContains(t, err, "response too large") + }) +} + +func TestClient_HandlerAssertionsDoNotRace(t *testing.T) { + t.Parallel() + + var ( + mu sync.Mutex + method string + ) + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + method = r.Method + mu.Unlock() + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"jwtAssertion": "jwt-token"}) + })) + defer server.Close() + + c := NewWithHTTPClient(mustParseURL(t, server.URL), server.Client()) + _, err := c.Enroll(context.Background(), "sak-token") + require.NoError(t, err) + + mu.Lock() + gotMethod := method + mu.Unlock() + assert.Equal(t, http.MethodPost, gotMethod) +} + +func TestMapEnrollErrorStatuses(t *testing.T) { + t.Parallel() + + cases := map[int]string{ + http.StatusBadRequest: "correct format", + http.StatusUnauthorized: "incorrect", + http.StatusForbidden: "incorrect/expired", + http.StatusNotFound: "not found", + http.StatusTooManyRequests: "retry after some time", + http.StatusBadGateway: "temporary issue", + http.StatusServiceUnavailable: "unavailable", + http.StatusGatewayTimeout: "slow to respond", + } + + for status, want := range cases { + got := mapEnrollError(&HTTPStatusError{StatusCode: status}) + require.ErrorContains(t, got, want) + } + + other := &HTTPStatusError{StatusCode: http.StatusTeapot} + require.Equal(t, other, mapEnrollError(other)) + + plain := errors.New("plain") + require.Equal(t, plain, mapEnrollError(plain)) +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { + return f(r) +} + +func mustParseURL(t *testing.T, raw string) *url.URL { + t.Helper() + parsed, err := url.Parse(raw) + require.NoError(t, err) + return parsed +} diff --git 
a/internal/backendclient/errors.go b/internal/backendclient/errors.go new file mode 100644 index 00000000..890c2dd4 --- /dev/null +++ b/internal/backendclient/errors.go @@ -0,0 +1,55 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import ( + "errors" + "fmt" + "strings" +) + +// ErrNotImplemented is returned by skeleton backend client methods. +var ErrNotImplemented = errors.New("backend client not implemented") + +// HTTPStatusError captures a non-2xx backend response. 
+type HTTPStatusError struct { + StatusCode int + Body string +} + +func (e *HTTPStatusError) Error() string { + body := sanitizeErrorBody(e.Body) + if body != "" { + return fmt.Sprintf("backend request failed with status %d: %s", e.StatusCode, body) + } + return fmt.Sprintf("backend request failed with status %d", e.StatusCode) +} + +func sanitizeErrorBody(body string) string { + const maxLen = 200 + + body = strings.TrimSpace(body) + if body == "" { + return "" + } + body = strings.ReplaceAll(body, "\n", " ") + body = strings.ReplaceAll(body, "\r", " ") + body = strings.Join(strings.Fields(body), " ") + if len(body) <= maxLen { + return body + } + return body[:maxLen] + "...(truncated)" +} diff --git a/internal/backendclient/errors_test.go b/internal/backendclient/errors_test.go new file mode 100644 index 00000000..ebee5327 --- /dev/null +++ b/internal/backendclient/errors_test.go @@ -0,0 +1,46 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package backendclient + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestHTTPStatusErrorError(t *testing.T) { + t.Parallel() + + err := (&HTTPStatusError{StatusCode: 500, Body: "line1\nline2"}).Error() + require.Contains(t, err, "status 500") + require.Contains(t, err, "line1 line2") + + err = (&HTTPStatusError{StatusCode: 404}).Error() + require.Equal(t, "backend request failed with status 404", err) +} + +func TestSanitizeErrorBody(t *testing.T) { + t.Parallel() + + require.Empty(t, sanitizeErrorBody(" \n\t ")) + require.Equal(t, "hello world", sanitizeErrorBody(" hello \n world ")) + + body := strings.Repeat("a", 220) + got := sanitizeErrorBody(body) + require.True(t, strings.HasSuffix(got, "...(truncated)")) + require.LessOrEqual(t, len(got), 214) +} diff --git a/internal/backendclient/types.go b/internal/backendclient/types.go new file mode 100644 index 00000000..e17d4d65 --- /dev/null +++ b/internal/backendclient/types.go @@ -0,0 +1,143 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package backendclient + +import "time" + +// NodeUpsertRequest is the backend DTO for node inventory upserts. 
+type NodeUpsertRequest struct { + Hostname string `json:"hostname"` + AgentConfig AgentConfig `json:"agentConfig,omitempty"` + Resources NodeResources `json:"resources"` + AgentVersion string `json:"agentVersion"` + GPUDriverVersion string `json:"gpuDriverVersion"` + CUDAVersion string `json:"cudaVersion"` + DCGMVersion string `json:"dcgmVersion"` + ContainerRuntimeVersion string `json:"containerRuntimeVersion"` + KernelVersion string `json:"kernelVersion"` + OSImage string `json:"osImage"` + OperatingSystem string `json:"operatingSystem"` + SystemUUID string `json:"systemUUID"` + MachineID string `json:"machineId"` + BootID string `json:"bootID"` + NetPrivateIP string `json:"netPrivateIP,omitempty"` +} + +type NodeResources struct { + CPUInfo CPUInfo `json:"cpuInfo"` + MemoryInfo MemoryInfo `json:"memoryInfo"` + GPUInfo GPUInfo `json:"gpuInfo"` + DiskInfo DiskInfo `json:"diskInfo"` + NICInfo NICInfo `json:"nicInfo"` +} + +type AgentConfig struct { + TotalComponents int64 `json:"totalComponents,omitempty"` + RetentionPeriodSeconds int64 `json:"retentionPeriodSeconds,omitempty"` + EnabledComponents []string `json:"enabledComponents,omitempty"` + DisabledComponents []string `json:"disabledComponents,omitempty"` +} + +type CPUInfo struct { + Type string `json:"type"` + Manufacturer string `json:"manufacturer"` + Architecture string `json:"architecture"` + LogicalCores string `json:"logicalCores"` +} + +type MemoryInfo struct { + TotalBytes string `json:"totalBytes"` +} + +type GPUInfo struct { + Product string `json:"product"` + Manufacturer string `json:"manufacturer"` + Architecture string `json:"architecture"` + Memory string `json:"memory"` + GPUs []GPUDevice `json:"gpus"` +} + +type GPUDevice struct { + UUID string `json:"uuid"` + BusID string `json:"busID"` + SN string `json:"sn"` + MinorID string `json:"minorID"` + BoardID int `json:"boardID"` + VBIOSVersion string `json:"vbiosVersion"` + ChassisSN string `json:"chassisSN"` + GPUIndex string 
`json:"gpuIndex,omitempty"` +} + +type DiskInfo struct { + ContainerRootDisk string `json:"containerRootDisk"` + BlockDevices []BlockDevice `json:"blockDevices"` +} + +type BlockDevice struct { + Name string `json:"name"` + Type string `json:"type"` + Size int64 `json:"size"` + WWN string `json:"wwn"` + MountPoint string `json:"mountPoint"` + FSType string `json:"fsType"` + PartUUID string `json:"partUUID"` + Parents []string `json:"parents"` +} + +type NICInfo struct { + PrivateIPInterfaces []NICInterface `json:"privateIPInterfaces"` +} + +type NICInterface struct { + Interface string `json:"interface"` + MAC string `json:"mac"` + IP string `json:"ip"` +} + +// NonceResponse is the backend DTO for node-scoped nonce responses. +type NonceResponse struct { + Nonce string `json:"nonce"` + JWTAssertion string `json:"jwtAssertion,omitempty"` + NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` +} + +// AttestationRequest is the backend DTO for attestation submission. +type AttestationRequest struct { + AttestationData AttestationData `json:"attestationData"` +} + +type AttestationData struct { + SDKResponse AttestationSDKResponse `json:"sdkResponse"` + NonceRefreshTimestamp time.Time `json:"nonceRefreshTimestamp"` + Success bool `json:"success"` + ErrorMessage string `json:"errorMessage,omitempty"` +} + +type AttestationSDKResponse struct { + Evidences []EvidenceItem `json:"evidences"` + ResultCode int `json:"resultCode"` + ResultMessage string `json:"resultMessage"` +} + +type EvidenceItem struct { + Arch string `json:"arch"` + Certificate string `json:"certificate"` + DriverVersion string `json:"driverVersion"` + Evidence string `json:"evidence"` + Nonce string `json:"nonce"` + VBIOSVersion string `json:"vbiosVersion"` + Version string `json:"version"` +} diff --git a/internal/config/config.go b/internal/config/config.go index f2eae596..711cc239 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -65,17 +65,36 @@ type Config struct { 
// SECURITY: Only accessible from localhost (127.0.0.0/8 or ::1). Disabled by default. EnableFaultInjection bool `json:"enable_fault_injection"` + // Inventory controls the periodic inventory loop. + Inventory *InventoryConfig `json:"inventory,omitempty"` + + // Attestation controls the periodic attestation loop. + Attestation *AttestationConfig `json:"attestation,omitempty"` + // Health Exporter Configuration HealthExporter *HealthExporterConfig `json:"health_exporter,omitempty"` } -// AttestationConfig holds configuration for the attestation process -type AttestationConfig struct { - // Interval is how often to run attestation (default: 24 hours) +// InventoryConfig holds configuration for the periodic inventory loop. +type InventoryConfig struct { + // Enabled controls whether the periodic inventory loop runs. + Enabled bool `json:"enabled"` + + // Interval is how often to collect and export inventory. Interval metav1.Duration `json:"interval"` +} - // JitterEnabled controls whether to add random jitter to attestation schedule - JitterEnabled bool `json:"jitter_enabled"` +// AttestationConfig holds configuration for the periodic attestation loop. +type AttestationConfig struct { + // Enabled controls whether the periodic attestation loop runs. + Enabled bool `json:"enabled"` + + // InitialInterval is how often to check enrollment and attempt the first attestation run + // before switching to the steady-state interval after the first successful attestation. + InitialInterval metav1.Duration `json:"initial_interval"` + + // Interval is how often to run attestation. 
+ Interval metav1.Duration `json:"interval"` } // HealthExporterConfig holds configuration for the health data exporter @@ -86,9 +105,6 @@ type HealthExporterConfig struct { // LogsEndpoint is the specific endpoint for sending logs/events data LogsEndpoint string `json:"logs_endpoint"` - // Attestation configuration - Attestation AttestationConfig `json:"attestation"` - // AuthToken is the authentication token for HTTP requests AuthToken string `json:"auth_token,omitempty"` @@ -151,6 +167,21 @@ func (config *Config) Validate() error { return fmt.Errorf("retention_period must be at least 1 minute, got %v", config.RetentionPeriod.Duration) } + if err := validateLoopConfig("inventory", config.Inventory); err != nil { + return err + } + if err := validateLoopConfig("attestation", config.Attestation); err != nil { + return err + } + if config.Attestation != nil && config.Attestation.Enabled { + if config.Attestation.InitialInterval.Duration <= 0 { + return errors.New("attestation.initial_interval is required when attestation is enabled") + } + if config.Attestation.InitialInterval.Duration < time.Minute { + return fmt.Errorf("attestation.initial_interval must be at least 1 minute, got %v", config.Attestation.InitialInterval.Duration) + } + } + // Validate health exporter configuration if present if config.HealthExporter != nil { // Validate health check interval @@ -183,6 +214,51 @@ func (config *Config) Validate() error { return nil } +func validateLoopConfig(name string, cfg interface { + GetEnabled() bool + GetInterval() time.Duration +}) error { + if cfg == nil || !cfg.GetEnabled() { + return nil + } + if cfg.GetInterval() <= 0 { + return fmt.Errorf("%s.interval is required when %s is enabled", name, name) + } + if cfg.GetInterval() < time.Minute { + return fmt.Errorf("%s.interval must be at least 1 minute, got %v", name, cfg.GetInterval()) + } + return nil +} + +func (c *InventoryConfig) GetEnabled() bool { + return c != nil && c.Enabled +} + +func (c 
*InventoryConfig) GetInterval() time.Duration { + if c == nil { + return 0 + } + return c.Interval.Duration +} + +func (c *AttestationConfig) GetEnabled() bool { + return c != nil && c.Enabled +} + +func (c *AttestationConfig) GetInterval() time.Duration { + if c == nil { + return 0 + } + return c.Interval.Duration +} + +func (c *AttestationConfig) GetInitialInterval() time.Duration { + if c == nil { + return 0 + } + return c.InitialInterval.Duration +} + // ShouldEnable returns true if the component should be enabled. // If no components are specified, all components are enabled by default. func (config *Config) ShouldEnable(componentName string) bool { @@ -264,6 +340,17 @@ func (config *Config) ToConfigEntries(allComponentNames []string) []ConfigEntry return entries } +// InventoryAgentConfig returns the useful, non-sensitive subset of agent config that should be +// persisted with inventory rather than exported through telemetry. +func (config *Config) InventoryAgentConfig(allComponentNames []string) (retentionPeriodSeconds int64, enabled, disabled []string) { + if config == nil { + return 0, nil, nil + } + + enabled, disabled = config.getComponentLists(allComponentNames) + return int64(config.RetentionPeriod.Seconds()), enabled, disabled +} + // getComponentLists computes enabled/disabled lists from config rules against all available components. 
func (config *Config) getComponentLists(allComponentNames []string) (enabled, disabled []string) { enabled, disabled = []string{}, []string{} @@ -345,13 +432,6 @@ func extractHealthExporterEntries(cfg *HealthExporterConfig) []ConfigEntry { strValue = fmt.Sprintf("%d", int64(duration.Seconds())) } else if d, ok := value.Interface().(time.Duration); ok { strValue = fmt.Sprintf("%d", int64(d.Seconds())) - } else if att, ok := value.Interface().(AttestationConfig); ok { - // Handle nested AttestationConfig - add individual fields - entries = append(entries, - ConfigEntry{Key: "health_exporter.attestation_interval", Value: fmt.Sprintf("%d", int64(att.Interval.Seconds()))}, - ConfigEntry{Key: "health_exporter.attestation_jitter_enabled", Value: fmt.Sprintf("%t", att.JitterEnabled)}, - ) - continue } default: strValue = fmt.Sprintf("%v", value.Interface()) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 3d229b91..c6467633 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -39,6 +39,13 @@ func TestDefault(t *testing.T) { assert.Equal(t, DefaultAPIVersion, cfg.APIVersion) assert.Equal(t, DefaultListenAddress, cfg.Address) assert.Equal(t, DefaultRetentionPeriod, cfg.RetentionPeriod) + require.NotNil(t, cfg.Inventory) + assert.True(t, cfg.Inventory.Enabled) + assert.Equal(t, metav1.Duration{Duration: 1 * time.Hour}, cfg.Inventory.Interval) + require.NotNil(t, cfg.Attestation) + assert.True(t, cfg.Attestation.Enabled) + assert.Equal(t, metav1.Duration{Duration: 5 * time.Minute}, cfg.Attestation.InitialInterval) + assert.Equal(t, metav1.Duration{Duration: 24 * time.Hour}, cfg.Attestation.Interval) // State path should be set assert.NotEmpty(t, cfg.State, "State path should be set") @@ -85,6 +92,97 @@ func TestConfigValidation(t *testing.T) { assert.Error(t, err) assert.Contains(t, err.Error(), "retention_period must be at least 1 minute") }) + + t.Run("inventory sync enabled without interval", func(t 
*testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Inventory: &InventoryConfig{ + Enabled: true, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "inventory.interval is required") + }) + + t.Run("inventory sync interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Inventory: &InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 500 * time.Millisecond}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "inventory.interval must be at least 1 minute") + }) + + t.Run("attestation enabled without interval", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.interval is required") + }) + + t.Run("attestation interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + Interval: metav1.Duration{Duration: 500 * time.Millisecond}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.interval must be at least 1 minute") + }) + + t.Run("attestation enabled without initial interval", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), 
"attestation.initial_interval is required") + }) + + t.Run("attestation initial interval too short", func(t *testing.T) { + cfg := &Config{ + Address: ":8080", + RetentionPeriod: metav1.Duration{Duration: time.Hour}, + Attestation: &AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 500 * time.Millisecond}, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, + } + + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "attestation.initial_interval must be at least 1 minute") + }) } func TestComponentSelection(t *testing.T) { @@ -555,9 +653,6 @@ func TestDefaultWithHealthExporter(t *testing.T) { require.NoError(t, err) assert.NotNil(t, cfg.HealthExporter) - // Attestation is always enabled, so we just check configuration - assert.Equal(t, metav1.Duration{Duration: 24 * time.Hour}, cfg.HealthExporter.Attestation.Interval) - assert.True(t, cfg.HealthExporter.Attestation.JitterEnabled) assert.Equal(t, metav1.Duration{Duration: 1 * time.Minute}, cfg.HealthExporter.Interval) assert.Equal(t, metav1.Duration{Duration: 30 * time.Second}, cfg.HealthExporter.Timeout) assert.True(t, cfg.HealthExporter.IncludeMetrics) @@ -634,12 +729,8 @@ func TestToConfigEntries(t *testing.T) { RetentionPeriod: metav1.Duration{Duration: 24 * time.Hour}, Components: []string{}, HealthExporter: &HealthExporterConfig{ - MetricsEndpoint: "https://example.com/metrics", - LogsEndpoint: "https://example.com/logs", - Attestation: AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: true, - }, + MetricsEndpoint: "https://example.com/metrics", + LogsEndpoint: "https://example.com/logs", Interval: metav1.Duration{Duration: 1 * time.Minute}, Timeout: metav1.Duration{Duration: 30 * time.Second}, IncludeMetrics: true, @@ -662,7 +753,6 @@ func TestToConfigEntries(t *testing.T) { // Check for health exporter entries foundMetricsEndpoint := false foundLogsEndpoint := false - foundAttestation := false 
foundAuthToken := false for _, entry := range entries { @@ -672,10 +762,6 @@ func TestToConfigEntries(t *testing.T) { if entry.Key == "health_exporter.logs_endpoint" { foundLogsEndpoint = true } - if entry.Key == "health_exporter.attestation_jitter_enabled" { - foundAttestation = true - assert.Equal(t, "true", entry.Value) - } if entry.Key == "auth_token" || entry.Key == "health_exporter.auth_token" { foundAuthToken = true } @@ -684,7 +770,6 @@ func TestToConfigEntries(t *testing.T) { // Endpoints are excluded - they're enrollment-assigned, not user config assert.False(t, foundMetricsEndpoint, "metrics_endpoint should not be exported") assert.False(t, foundLogsEndpoint, "logs_endpoint should not be exported") - assert.True(t, foundAttestation) assert.False(t, foundAuthToken, "auth_token should not be exported") }) @@ -822,3 +907,17 @@ func TestGetComponentLists(t *testing.T) { assert.Equal(t, "[]", string(disabledJSON)) }) } + +func TestInventoryAgentConfig(t *testing.T) { + allComponents := []string{"cpu", "disk", "memory", "gpu"} + cfg := &Config{ + APIVersion: "v1", + RetentionPeriod: metav1.Duration{Duration: 24 * time.Hour}, + Components: []string{"*", "-memory", "-disk"}, + } + + retentionPeriodSeconds, enabled, disabled := cfg.InventoryAgentConfig(allComponents) + assert.Equal(t, int64(86400), retentionPeriodSeconds) + assert.ElementsMatch(t, []string{"cpu", "gpu"}, enabled) + assert.ElementsMatch(t, []string{"memory", "disk"}, disabled) +} diff --git a/internal/config/default.go b/internal/config/default.go index ec55391a..9d413dbc 100644 --- a/internal/config/default.go +++ b/internal/config/default.go @@ -75,17 +75,22 @@ func Default(ctx context.Context, opts ...OpOption) (*Config, error) { Address: DefaultListenAddress, RetentionPeriod: DefaultRetentionPeriod, EnableFaultInjection: false, // Disabled by default for security + Inventory: &InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 1 * time.Hour}, + }, + Attestation: 
&AttestationConfig{ + Enabled: true, + InitialInterval: metav1.Duration{Duration: 5 * time.Minute}, + Interval: metav1.Duration{Duration: 24 * time.Hour}, + }, NvidiaToolOverwrites: nvidiacommon.ToolOverwrites{ InfinibandClassRootDir: options.InfinibandClassRootDir, }, // Health exporter is enabled by default HealthExporter: &HealthExporterConfig{ - MetricsEndpoint: "", - LogsEndpoint: "", - Attestation: AttestationConfig{ - Interval: metav1.Duration{Duration: 24 * time.Hour}, - JitterEnabled: true, - }, + MetricsEndpoint: "", + LogsEndpoint: "", AuthToken: "", Interval: metav1.Duration{Duration: 1 * time.Minute}, Timeout: metav1.Duration{Duration: 30 * time.Second}, diff --git a/internal/endpoint/endpoint.go b/internal/endpoint/endpoint.go index 8d10e713..321fade4 100644 --- a/internal/endpoint/endpoint.go +++ b/internal/endpoint/endpoint.go @@ -109,18 +109,51 @@ func AgentBaseURL(serverURL *url.URL) *url.URL { return serverURL } -// ValidateBackendEndpoint validates a trusted backend HTTPS endpoint. +// ValidateBackendEndpoint validates a trusted backend endpoint. +// Production backends must use HTTPS. HTTP is accepted only for loopback hosts +// to support local development and smoke tests. func ValidateBackendEndpoint(raw string) (*url.URL, error) { parsed, err := parseURL(raw) if err != nil { return nil, err } + if parsed.Scheme == "http" { + host := parsed.Hostname() + if host == "" || !isLoopbackHost(host) { + return nil, fmt.Errorf("backend endpoint over http must use localhost or a loopback IP, got %q", host) + } + return parsed, nil + } if err := requireScheme(parsed, "https"); err != nil { return nil, err } return parsed, nil } +// DeriveBackendBaseURL converts a legacy backend endpoint URL into its backend base URL. +// For example, "https://backend.example.com/api/v1/health/metrics" becomes +// "https://backend.example.com". 
+func DeriveBackendBaseURL(raw string) (string, error) { + parsed, err := parseURL(raw) + if err != nil { + return "", err + } + if parsed.Scheme == "http" { + host := parsed.Hostname() + if host == "" || !isLoopbackHost(host) { + return "", fmt.Errorf("backend endpoint over http must use localhost or a loopback IP, got %q", host) + } + } else { + if err := requireScheme(parsed, "https"); err != nil { + return "", err + } + } + return (&url.URL{ + Scheme: parsed.Scheme, + Host: parsed.Host, + }).String(), nil +} + // JoinPath appends path elements to a validated base URL. func JoinPath(base *url.URL, elems ...string) (string, error) { return url.JoinPath(base.String(), elems...) diff --git a/internal/endpoint/endpoint_test.go b/internal/endpoint/endpoint_test.go index fe9fde54..3a084821 100644 --- a/internal/endpoint/endpoint_test.go +++ b/internal/endpoint/endpoint_test.go @@ -194,7 +194,29 @@ func TestValidateBackendEndpoint(t *testing.T) { _, err := ValidateBackendEndpoint("https://example.com/base") require.NoError(t, err) + _, err = ValidateBackendEndpoint("http://localhost:8080") + require.NoError(t, err) + + _, err = ValidateBackendEndpoint("http://127.0.0.1:8080") + require.NoError(t, err) + _, err = ValidateBackendEndpoint("http://example.com/base") require.Error(t, err) - assert.Contains(t, err.Error(), "https") + assert.Contains(t, err.Error(), "loopback") +} + +func TestDeriveBackendBaseURL(t *testing.T) { + t.Parallel() + + got, err := DeriveBackendBaseURL("https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + assert.Equal(t, "https://backend.example.com", got) + + got, err = DeriveBackendBaseURL("http://localhost:8080/api/v1/health/metrics") + require.NoError(t, err) + assert.Equal(t, "http://localhost:8080", got) + + _, err = DeriveBackendBaseURL("http://example.com/api/v1/health/metrics") + require.Error(t, err) + assert.Contains(t, err.Error(), "loopback") } diff --git a/internal/enrollment/enrollment.go 
b/internal/enrollment/enrollment.go index 67c26175..4cf067b7 100644 --- a/internal/enrollment/enrollment.go +++ b/internal/enrollment/enrollment.go @@ -13,123 +13,141 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package enrollment provides shared enrollment functionality for the Fleet Intelligence agent +// Package enrollment provides shared enrollment functionality for the Fleet Intelligence agent. package enrollment import ( "context" - "encoding/json" "fmt" - "io" - "net/http" - "os" - "time" + "net/url" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" + pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" + nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/config" + "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" + inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" + "github.com/NVIDIA/fleet-intelligence-agent/internal/registry" ) -const maxEnrollmentResponseSize = 1 << 20 +var ( + newBackendClient = backendclient.New + syncInventoryAfterEnroll = syncInventoryOnce +) + +// Enroll runs the full enrollment workflow and performs a best-effort initial inventory sync. 
+func Enroll(ctx context.Context, baseEndpoint, sakToken string) error { + baseURL, err := normalizeBackendBaseURL(baseEndpoint) + if err != nil { + return fmt.Errorf("invalid enrollment endpoint: %w", err) + } -// EnrollResponse represents the response from the enrollment endpoint -type EnrollResponse struct { - JWTToken string `json:"jwt_assertion"` + client, err := newBackendClient(baseURL.String()) + if err != nil { + return fmt.Errorf("failed to create backend client: %w", err) + } + jwtToken, err := client.Enroll(ctx, sakToken) + if err != nil { + return err + } + if err := storeConfigInMetadata(ctx, baseURL.String(), jwtToken, sakToken); err != nil { + return fmt.Errorf("failed to store configuration: %w", err) + } + if err := syncInventoryAfterEnroll(ctx); err != nil { + log.Logger.Warnw("post-enroll inventory sync failed", "error", err) + } + return nil } -// PerformEnrollment performs the enrollment request to get a new JWT token -func PerformEnrollment(ctx context.Context, enrollEndpoint, sakToken string) (string, error) { - if enrollEndpoint == "" { - return "", fmt.Errorf("enrollEndpoint cannot be empty") +func normalizeBackendBaseURL(raw string) (*url.URL, error) { + baseURL, err := endpoint.ValidateBackendEndpoint(raw) + if err != nil { + return nil, err } - if sakToken == "" { - return "", fmt.Errorf("sakToken cannot be empty") + if baseURL.Path == "" || baseURL.Path == "/" { + return baseURL, nil } - // Use the provided enrollment endpoint directly - enrollURL := enrollEndpoint + normalized, err := endpoint.DeriveBackendBaseURL(raw) + if err != nil { + return nil, err + } + return endpoint.ValidateBackendEndpoint(normalized) +} - // Create HTTP client with timeout. Disable redirects so a compromised - // backend cannot bounce us to an internal service (SSRF). 
- client := &http.Client{ - Timeout: 30 * time.Second, - CheckRedirect: func(*http.Request, []*http.Request) error { - return http.ErrUseLastResponse - }, +func storeConfigInMetadata(ctx context.Context, baseURL, jwtToken, sakToken string) error { + stateFile, err := config.DefaultStateFile() + if err != nil { + return fmt.Errorf("failed to get state file path: %w", err) } - // Create HTTP request with empty body - req, err := http.NewRequestWithContext(ctx, "POST", enrollURL, nil) + dbRW, err := sqlite.Open(stateFile) if err != nil { - return "", fmt.Errorf("failed to create HTTP request: %w", err) + return fmt.Errorf("failed to open state database: %w", err) + } + defer dbRW.Close() + + if err := config.SecureStateFilePermissions(stateFile); err != nil { + return fmt.Errorf("failed to secure state database permissions: %w", err) + } + if err := pkgmetadata.CreateTableMetadata(ctx, dbRW); err != nil { + return fmt.Errorf("failed to create metadata table: %w", err) + } + + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeySAKToken, sakToken); err != nil { + return fmt.Errorf("failed to set SAK token: %w", err) + } + if err := pkgmetadata.SetMetadata(ctx, dbRW, pkgmetadata.MetadataKeyToken, jwtToken); err != nil { + return fmt.Errorf("failed to set JWT token: %w", err) } + if err := pkgmetadata.SetMetadata(ctx, dbRW, agentstate.MetadataKeyBackendBaseURL, baseURL); err != nil { + return fmt.Errorf("failed to set backend base URL: %w", err) + } + return nil +} - // Set headers (no Content-Type since no body is sent) - req.Header.Set("User-Agent", "fleet-intelligence-agent") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", sakToken)) +type machineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) + +func (f machineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { + return f(ctx) +} - // Make the request - resp, err := client.Do(req) +func syncInventoryOnce(ctx context.Context) error { + 
state := agentstate.NewSQLite() + sink := inventorysink.NewBackendSink(state) + allComponents := registry.AllComponentNames() + + cfg, err := config.Default(ctx) if err != nil { - log.Logger.Errorw("Enrollment request failed", "error", err, "url", enrollURL) - return "", fmt.Errorf("failed to make enrollment request: %w", err) + return fmt.Errorf("load default config for inventory sync: %w", err) } - defer resp.Body.Close() + retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) - // Read at most max+1 bytes so oversized responses fail explicitly instead of - // surfacing later as a truncated JSON parse error. - respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxEnrollmentResponseSize+1)) + nvmlInstance, err := nvidianvml.New() if err != nil { - log.Logger.Errorw("Failed to read enrollment response body", "error", err) - return "", fmt.Errorf("failed to read enrollment response: %w", err) - } - if len(respBody) > maxEnrollmentResponseSize { - err = fmt.Errorf("enrollment response too large (max %d bytes)", maxEnrollmentResponseSize) - log.Logger.Errorw("Failed to read enrollment response body", "error", err) - return "", err - } - - // Check response status and return specific error messages - if resp.StatusCode != http.StatusOK { - var errMsg string - switch resp.StatusCode { - case http.StatusBadRequest: // 400 - errMsg = "The token used in the enrollment is not in the correct format. Please check the token. If all else fails, generate a new token by going to the UI" - case http.StatusUnauthorized: // 401 - errMsg = "The token used in the enrollment is incorrect. Please generate a new token by going to the UI or make sure you are using the correct token" - case http.StatusForbidden: // 403 - errMsg = "The token used in the enrollment is incorrect/expired. 
Please generate a new token by going to the UI or make sure you are using the correct token" - case http.StatusNotFound: // 404 - errMsg = "The endpoint is not found. Please use the correct endpoint" - case http.StatusTooManyRequests: // 429 - errMsg = "Please retry after some time. Server is under heavy load" - case http.StatusBadGateway: // 502 - errMsg = "Some temporary issue caused enrollment to fail. Please try again" - case http.StatusServiceUnavailable: // 503 - errMsg = "Service is unavailable currently. Please try again" - case http.StatusGatewayTimeout: // 504 - errMsg = "Service is experiencing load and is slow to respond. Please try again maybe after a few minutes" - default: - errMsg = fmt.Sprintf("enrollment request failed with status %d", resp.StatusCode) - } - - // Print error to stderr - fmt.Fprintf(os.Stderr, "Enrollment failed: %s\n", errMsg) - return "", fmt.Errorf("%s", errMsg) - } - - // Parse response - var enrollResp EnrollResponse - if err := json.Unmarshal(respBody, &enrollResp); err != nil { - log.Logger.Errorw("Failed to parse enrollment response JSON", "error", err) - return "", fmt.Errorf("failed to parse enrollment response: %w", err) - } - - // Validate JWT token is present - if enrollResp.JWTToken == "" { - log.Logger.Errorw("Enrollment response missing jwt-token field") - return "", fmt.Errorf("enrollment response missing jwt-token field") - } - - // Print success to stdout - fmt.Fprintf(os.Stdout, "Enrollment succeeded\n") - return enrollResp.JWTToken, nil + return fmt.Errorf("initialize nvml for inventory sync: %w", err) + } + defer func() { _ = nvmlInstance.Shutdown() }() + + src := inventorysource.NewMachineInfoSourceWithAgentConfig( + machineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance) + }), + &inventory.AgentConfig{ + TotalComponents: int64(len(allComponents)), + RetentionPeriodSeconds: retentionPeriodSeconds, + EnabledComponents: 
enabledComponents, + DisabledComponents: disabledComponents, + }, + ) + manager := inventory.NewManager(src, sink, inventory.InventoryConfig{}) + _, err = manager.CollectOnce(ctx) + return err } diff --git a/internal/enrollment/enrollment_test.go b/internal/enrollment/enrollment_test.go index 74a1cc35..027a7413 100644 --- a/internal/enrollment/enrollment_test.go +++ b/internal/enrollment/enrollment_test.go @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,318 +17,217 @@ package enrollment import ( "context" - "encoding/json" - "fmt" - "net/http" - "net/http/httptest" + "errors" + "os" "testing" - "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" -) - -func TestPerformEnrollment_Success(t *testing.T) { - expectedToken := "test-jwt-token-12345" - - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request method - assert.Equal(t, "POST", r.Method) - // Verify headers - assert.Equal(t, "fleet-intelligence-agent", r.Header.Get("User-Agent")) - assert.Equal(t, "Bearer test-sak-token", r.Header.Get("Authorization")) - - // Send successful response - response := EnrollResponse{ - JWTToken: expectedToken, - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/config" +) - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") +type fakeBackendClient struct { + enrollSAK string + enrollJWT string + enrollErr error 
+} - require.NoError(t, err) - assert.Equal(t, expectedToken, token) +func (f *fakeBackendClient) Enroll(_ context.Context, sakToken string) (string, error) { + f.enrollSAK = sakToken + return f.enrollJWT, f.enrollErr } -func TestPerformEnrollment_EmptyEndpoint(t *testing.T) { - ctx := context.Background() - token, err := PerformEnrollment(ctx, "", "test-sak-token") +func (f *fakeBackendClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error { + return nil +} - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollEndpoint cannot be empty") - assert.Empty(t, token) +func (f *fakeBackendClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil } -func TestPerformEnrollment_EmptyToken(t *testing.T) { - ctx := context.Background() - token, err := PerformEnrollment(ctx, "http://example.com", "") +func (f *fakeBackendClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil +} - require.Error(t, err) - assert.Contains(t, err.Error(), "sakToken cannot be empty") - assert.Empty(t, token) +func (f *fakeBackendClient) RefreshToken(context.Context, string) (string, error) { + return "", nil } -func TestPerformEnrollment_HTTPStatusCodes(t *testing.T) { - tests := []struct { - name string - statusCode int - expectedErrMsg string - }{ - { - name: "BadRequest_400", - statusCode: http.StatusBadRequest, - expectedErrMsg: "The token used in the enrollment is not in the correct format", - }, - { - name: "Unauthorized_401", - statusCode: http.StatusUnauthorized, - expectedErrMsg: "The token used in the enrollment is incorrect", - }, - { - name: "Forbidden_403", - statusCode: http.StatusForbidden, - expectedErrMsg: "The token used in the enrollment is incorrect/expired", - }, - { - name: "NotFound_404", - statusCode: http.StatusNotFound, - expectedErrMsg: "The endpoint is not found", - }, - { - name: "TooManyRequests_429", - 
statusCode: http.StatusTooManyRequests, - expectedErrMsg: "Please retry after some time", - }, - { - name: "BadGateway_502", - statusCode: http.StatusBadGateway, - expectedErrMsg: "Some temporary issue caused enrollment to fail", - }, - { - name: "ServiceUnavailable_503", - statusCode: http.StatusServiceUnavailable, - expectedErrMsg: "Service is unavailable currently", - }, - { - name: "GatewayTimeout_504", - statusCode: http.StatusGatewayTimeout, - expectedErrMsg: "Service is experiencing load and is slow to respond", - }, - { - name: "InternalServerError_500", - statusCode: http.StatusInternalServerError, - expectedErrMsg: "enrollment request failed with status 500", - }, +func TestEnrollWorkflow(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + client := &fakeBackendClient{enrollJWT: "jwt-token"} + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://example.com", rawBaseURL) + return client, nil } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tt.statusCode) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), tt.expectedErrMsg) - assert.Empty(t, token) - }) + syncCalled := false + syncInventoryAfterEnroll = func(ctx context.Context) error { + syncCalled = true + return nil } -} -func TestPerformEnrollment_DoesNotFollowRedirects(t *testing.T) { - redirectTargetCalled := false - redirectTarget := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - redirectTargetCalled = true - w.WriteHeader(http.StatusOK) - })) - defer redirectTarget.Close() - - server := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, redirectTarget.URL, http.StatusFound) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - - require.Error(t, err) - assert.Contains(t, err.Error(), "status 302") - assert.Empty(t, token) - assert.False(t, redirectTargetCalled) + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.NoError(t, err) + require.Equal(t, "sak-token", client.enrollSAK) + require.True(t, syncCalled) } -func TestPerformEnrollment_MissingJWTToken(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Send response without JWT token - response := EnrollResponse{ - JWTToken: "", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() +func TestEnrollWorkflowNormalizesLegacyEndpointToBaseURL(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) + + client := &fakeBackendClient{enrollJWT: "jwt-token"} + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + require.Equal(t, "https://example.com", rawBaseURL) + return client, nil + } + syncInventoryAfterEnroll = func(context.Context) error { return nil } - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollment response missing jwt-token field") - assert.Empty(t, token) + err := Enroll(context.Background(), "https://example.com/api/v1/health/metrics", "sak-token") + 
require.NoError(t, err) + require.Equal(t, "sak-token", client.enrollSAK) } -func TestPerformEnrollment_InvalidJSON(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write([]byte("invalid json")) - require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") +func TestEnrollWorkflowErrors(t *testing.T) { + t.Run("invalid endpoint", func(t *testing.T) { + err := Enroll(context.Background(), "http://example.com", "sak-token") + require.Error(t, err) + }) + + t.Run("localhost http endpoint allowed", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + + called := false + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + called = true + require.Equal(t, "http://localhost:8080", rawBaseURL) + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to parse enrollment response") - assert.Empty(t, token) -} + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) -func TestPerformEnrollment_ResponseTooLarge(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, err := w.Write(make([]byte, maxEnrollmentResponseSize+1)) + err := Enroll(context.Background(), "http://localhost:8080", "sak-token") require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + require.True(t, called) + }) + + t.Run("backend client creation", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) 
+ newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return nil, errors.New("factory boom") + } - require.Error(t, err) - assert.Contains(t, err.Error(), "enrollment response too large") - assert.Empty(t, token) -} + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.ErrorContains(t, err, "failed to create backend client") + }) -func TestPerformEnrollment_ContextCancellation(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Simulate slow response - time.Sleep(2 * time.Second) - w.WriteHeader(http.StatusOK) - })) - defer server.Close() + t.Run("enroll error", func(t *testing.T) { + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return &fakeBackendClient{enrollErr: errors.New("enroll boom")}, nil + } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + err := Enroll(context.Background(), "https://example.com", "sak-token") + require.ErrorContains(t, err, "enroll boom") + }) - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + t.Run("localhost legacy endpoint allowed", func(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to make enrollment request") - assert.Empty(t, token) -} + called := false + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + called = true + require.Equal(t, "http://localhost:8080", rawBaseURL) + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + syncInventoryAfterEnroll = func(context.Context) error { return nil } -func TestPerformEnrollment_ServerUnavailable(t *testing.T) { - // Use an 
invalid URL that will fail to connect - ctx := context.Background() - token, err := PerformEnrollment(ctx, "http://localhost:99999", "test-sak-token") + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) - require.Error(t, err) - assert.Contains(t, err.Error(), "failed to make enrollment request") - assert.Empty(t, token) + err := Enroll(context.Background(), "http://localhost:8080/api/v1/health/enroll", "sak-token") + require.NoError(t, err) + require.True(t, called) + }) } -func TestPerformEnrollment_RequestBodyEmpty(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify no body is sent (Content-Length should be 0 or not set) - assert.Equal(t, int64(0), r.ContentLength) +func TestEnrollWorkflowInventorySyncFailureIsNonFatal(t *testing.T) { + originalFactory := newBackendClient + originalSync := syncInventoryAfterEnroll + t.Cleanup(func() { + newBackendClient = originalFactory + syncInventoryAfterEnroll = originalSync + }) - response := EnrollResponse{ - JWTToken: "test-token", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) - require.NoError(t, err) - })) - defer server.Close() + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + return &fakeBackendClient{enrollJWT: "jwt-token"}, nil + } + syncInventoryAfterEnroll = func(ctx context.Context) error { + return errors.New("inventory failed") + } - ctx := context.Background() - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) + err := Enroll(context.Background(), "https://example.com", "sak-token") require.NoError(t, err) - assert.NotEmpty(t, token) } -func TestEnrollResponse_JSONSerialization(t *testing.T) { - tests := []struct { - name string - response EnrollResponse - }{ - { - name: "with_token", - response: EnrollResponse{ - JWTToken: "test-jwt-token", - }, - }, - { 
- name: "empty_token", - response: EnrollResponse{ - JWTToken: "", - }, - }, +func TestStoreConfigInMetadataSecuresFreshStateFile(t *testing.T) { + if os.Geteuid() == 0 { + t.Skip("test expects non-root default state path resolution") } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Test marshaling - data, err := json.Marshal(tt.response) - assert.NoError(t, err) - - // Test unmarshaling - var unmarshaled EnrollResponse - err = json.Unmarshal(data, &unmarshaled) - assert.NoError(t, err) - assert.Equal(t, tt.response.JWTToken, unmarshaled.JWTToken) - }) - } -} + tmpHome := t.TempDir() + t.Setenv("HOME", tmpHome) -func TestPerformEnrollment_MultipleSuccessiveRequests(t *testing.T) { - requestCount := 0 + err := storeConfigInMetadata( + context.Background(), + "https://example.com", + "jwt-token", + "sak-token", + ) + require.NoError(t, err) - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - response := EnrollResponse{ - JWTToken: fmt.Sprintf("token-%d", requestCount), + stateFile, err := config.DefaultStateFile() + require.NoError(t, err) + for _, candidate := range []string{stateFile, stateFile + "-wal", stateFile + "-shm"} { + info, err := os.Stat(candidate) + if os.IsNotExist(err) { + if candidate == stateFile { + require.NoError(t, err) + } + continue } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - err := json.NewEncoder(w).Encode(response) require.NoError(t, err) - })) - defer server.Close() - - ctx := context.Background() - - // Make multiple requests - for i := 1; i <= 3; i++ { - token, err := PerformEnrollment(ctx, server.URL, "test-sak-token") - require.NoError(t, err) - assert.Equal(t, fmt.Sprintf("token-%d", i), token) + assert.Equal(t, os.FileMode(0o600), info.Mode().Perm()) } - - assert.Equal(t, 3, requestCount) } diff --git a/internal/exporter/collector/collector.go b/internal/exporter/collector/collector.go index 
acbe48ec..b7cd5f70 100644 --- a/internal/exporter/collector/collector.go +++ b/internal/exporter/collector/collector.go @@ -31,13 +31,10 @@ import ( nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" "github.com/google/uuid" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -const initialMachineInfoWait = 5 * time.Second - // GenerateCollectionID generates a unique identifier for a data collection cycle func GenerateCollectionID() string { bytes := make([]byte, 16) @@ -52,15 +49,14 @@ func GenerateEventID() string { // HealthData represents the collected health data type HealthData struct { - CollectionID string - MachineID string - Timestamp time.Time - MachineInfo *machineinfo.MachineInfo - Metrics pkgmetrics.Metrics - Events eventstore.Events - ComponentData map[string]interface{} - AttestationData *attestation.AttestationData - ConfigEntries []config.ConfigEntry + CollectionID string + MachineID string + Timestamp time.Time + MachineInfo *machineinfo.MachineInfo + GPUUUIDToIndex map[string]string + Metrics pkgmetrics.Metrics + Events eventstore.Events + ComponentData map[string]interface{} } // Collector defines the interface for collecting health data @@ -70,17 +66,13 @@ type Collector interface { // collector implements the Collector interface type collector struct { - config *config.HealthExporterConfig - configEntries []config.ConfigEntry // Cached config entries computed once at startup - metricsStore pkgmetrics.Store - eventStore eventstore.Store - componentsRegistry components.Registry - nvmlInstance nvidianvml.Instance - attestationManager *attestation.Manager - lastAttestationCollection time.Time - machineID string // Agent's stable identity from server initialization - dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices - machineInfoProvider 
machineInfoProvider + config *config.HealthExporterConfig + metricsStore pkgmetrics.Store + eventStore eventstore.Store + componentsRegistry components.Registry + nvmlInstance nvidianvml.Instance + machineID string // Agent's stable identity from server initialization + dcgmGPUIndexes map[string]string // UUID → DCGM device ID override for GPU indices } // New creates a new health data collector @@ -92,38 +84,18 @@ func New( eventStore eventstore.Store, componentsRegistry components.Registry, nvmlInstance nvidianvml.Instance, - attestationManager *attestation.Manager, + _ any, machineID string, dcgmGPUIndexes map[string]string, ) Collector { - // Compute config entries once at startup (no dynamic config reload) - var configEntries []config.ConfigEntry - if fullConfig != nil { - configEntries = fullConfig.ToConfigEntries(allComponentNames) - log.Logger.Infow("Config entries computed at startup", "entries_count", len(configEntries)) - } - - var provider machineInfoProvider - if cfg != nil && cfg.IncludeMachineInfo && nvmlInstance != nil { - var opts []machineinfo.MachineInfoOption - if len(dcgmGPUIndexes) > 0 { - opts = append(opts, machineinfo.WithDCGMGPUIndexes(dcgmGPUIndexes)) - } - provider = newCachedMachineInfoProvider(nvmlInstance, 0, opts...) 
- provider.RefreshAsync(context.Background()) - } - return &collector{ - config: cfg, - configEntries: configEntries, - metricsStore: metricsStore, - eventStore: eventStore, - componentsRegistry: componentsRegistry, - nvmlInstance: nvmlInstance, - attestationManager: attestationManager, - machineID: machineID, - dcgmGPUIndexes: dcgmGPUIndexes, - machineInfoProvider: provider, + config: cfg, + metricsStore: metricsStore, + eventStore: eventStore, + componentsRegistry: componentsRegistry, + nvmlInstance: nvmlInstance, + machineID: machineID, + dcgmGPUIndexes: dcgmGPUIndexes, } } @@ -139,14 +111,10 @@ func (c *collector) Collect(ctx context.Context) (*HealthData, error) { } data := &HealthData{ - CollectionID: collectionID, - MachineID: c.machineID, - Timestamp: time.Now().UTC(), - } - - // Collect machine info if enabled - if c.config.IncludeMachineInfo { - c.collectMachineInfo(ctx, data) + CollectionID: collectionID, + MachineID: c.machineID, + Timestamp: time.Now().UTC(), + GPUUUIDToIndex: cloneStringMap(c.dcgmGPUIndexes), } // Collect metrics if enabled @@ -170,37 +138,9 @@ func (c *collector) Collect(ctx context.Context) (*HealthData, error) { } } - // Collect attestation data if provider is available - // Attestation is always enabled if manager is available - if err := c.collectAttestationData(data); err != nil { - log.Logger.Errorw("Failed to collect attestation data", "error", err) - } - - // Collect config data - if err := c.collectConfigData(data); err != nil { - log.Logger.Errorw("Failed to collect config data", "error", err) - } - return data, nil } -// collectMachineInfo reads cached machine info and triggers a best-effort refresh. 
-func (c *collector) collectMachineInfo(ctx context.Context, data *HealthData) { - if c.machineInfoProvider == nil { - return - } - - if _, ok := c.machineInfoProvider.Get(); !ok { - c.machineInfoProvider.WaitForInitialRefresh(ctx, initialMachineInfoWait) - } - - if machineInfo, ok := c.machineInfoProvider.Get(); ok { - data.MachineInfo = machineInfo - } - - c.machineInfoProvider.RefreshAsync(ctx) -} - // collectMetrics collects metrics data from the metrics store func (c *collector) collectMetrics(ctx context.Context, data *HealthData) error { if c.metricsStore == nil { @@ -336,34 +276,14 @@ func (c *collector) collectComponentData(data *HealthData) error { return nil } -// collectAttestationData collects attestation data from the attestation manager if available and updated -func (c *collector) collectAttestationData(data *HealthData) error { - if c.attestationManager == nil { - log.Logger.Debugw("No attestation manager available, skipping attestation data collection") +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { return nil } - // Get latest attestation data (success or failure info) - attestationData := c.attestationManager.GetAttestationData() - data.AttestationData = attestationData - - // Update collection timestamp if data was newly updated - if c.attestationManager.IsAttestationDataUpdated(c.lastAttestationCollection) { - c.lastAttestationCollection = time.Now().UTC() - } - - return nil -} - -// collectConfigData returns cached agent configuration entries -// Config entries are computed once at startup since there's no dynamic config reload -func (c *collector) collectConfigData(data *HealthData) error { - if len(c.configEntries) == 0 { - log.Logger.Debugw("No config entries available, skipping config data collection") - return nil + out := make(map[string]string, len(in)) + for k, v := range in { + out[k] = v } - - // Return cached config entries (computed once at startup) - data.ConfigEntries = c.configEntries - 
return nil + return out } diff --git a/internal/exporter/collector/collector_test.go b/internal/exporter/collector/collector_test.go index a851d598..c64c150a 100644 --- a/internal/exporter/collector/collector_test.go +++ b/internal/exporter/collector/collector_test.go @@ -19,7 +19,6 @@ import ( "context" "encoding/hex" "errors" - "sync" "sync/atomic" "testing" "time" @@ -34,310 +33,10 @@ import ( "github.com/stretchr/testify/require" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) -func TestCollector_AttestationDataCollection(t *testing.T) { - tests := []struct { - name string - description string - testLogic func(t *testing.T) - }{ - { - name: "first_collection_always_collects", - description: "First collection should always collect attestation data even if empty", - testLogic: testFirstCollectionAlwaysCollects, - }, - { - name: "subsequent_collection_skips_when_no_update", - description: "Subsequent collections should skip when attestation data hasn't been updated", - testLogic: testSubsequentCollectionSkipsWhenNoUpdate, - }, - { - name: "collection_after_attestation_update", - description: "Collection should happen after attestation data is updated", - testLogic: testCollectionAfterAttestationUpdate, - }, - { - name: "nil_attestation_manager_skips_collection", - description: "Collector should skip attestation collection when manager is nil", - testLogic: testNilAttestationManagerSkipsCollection, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Log("Testing:", tt.description) - tt.testLogic(t) - }) - } -} - -func testFirstCollectionAlwaysCollects(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := 
attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - - // Start attestation manager to populate some data - attestationManager.Start() - defer attestationManager.Stop() - - // Wait a bit for attestation to run - time.Sleep(100 * time.Millisecond) - - // Collect data for the first time - data, err := testCollector.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - - // First collection should work (but may not have attestation data due to test environment) - if data.AttestationData != nil { - t.Log("First collection successfully populated attestation data") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails in test environment") - } - - // Check if lastAttestationCollection was updated only if attestation data was collected - collectorImpl := testCollector.(*collector) - if data.AttestationData != nil { - assert.False(t, collectorImpl.lastAttestationCollection.IsZero(), - "lastAttestationCollection should be set after successful collection") - } else { - assert.True(t, collectorImpl.lastAttestationCollection.IsZero(), - "lastAttestationCollection should remain zero when attestation fails") - } -} - -func testSubsequentCollectionSkipsWhenNoUpdate(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Start attestation manager - attestationManager.Start() - defer attestationManager.Stop() - - // Wait for attestation to run and populate data - time.Sleep(100 * time.Millisecond) - - // First collection - data1, err := testCollector.Collect(ctx) - 
require.NoError(t, err) - require.NotNil(t, data1) - - firstCollectionTime := collectorImpl.lastAttestationCollection - // In test environment, this will be zero since attestation fails - t.Logf("First collection time: %v", firstCollectionTime) - - // Verify first collection has attestation data - // Verify first collection has attestation data (or logs why it doesn't) - if data1.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails") - } - - // Sleep a little to ensure time difference - time.Sleep(10 * time.Millisecond) - - // Second collection (should skip attestation since no update) - data2, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data2) - - secondCollectionTime := collectorImpl.lastAttestationCollection - - // lastAttestationCollection should remain the same (indicating skip) - assert.Equal(t, firstCollectionTime, secondCollectionTime, - "lastAttestationCollection should not change when attestation collection is skipped") -} - -func testCollectionAfterAttestationUpdate(t *testing.T) { - ctx := context.Background() - cfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, cfg) // nil nvmlInstance, 20s for testing - - // Create collector - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Start attestation manager with faster interval for testing (20 seconds) - attestationManager.Start() - defer attestationManager.Stop() - - // Wait for first attestation to run - time.Sleep(100 * time.Millisecond) - - // First collection - data1, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data1) - - 
firstCollectionTime := collectorImpl.lastAttestationCollection - - // Verify first collection has attestation data - // Verify first collection has attestation data (or logs why it doesn't) - if data1.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - } else { - t.Log("First collection did not populate attestation data - this is expected when NVML/nonce fails") - } - - // Wait for attestation to run again (it's set to 20 seconds in the test) - t.Log("Waiting for attestation to refresh...") - time.Sleep(10 * time.Second) - - // Second collection (should collect since attestation was updated) - data2, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data2) - - secondCollectionTime := collectorImpl.lastAttestationCollection - - // In test environment, both times will be zero since attestation fails - t.Logf("First collection time: %v, Second collection time: %v", firstCollectionTime, secondCollectionTime) - - // In a real environment with working NVML/HTTP, both collections would have evidence data - // In test environment, they will be nil due to missing dependencies - if data1.AttestationData != nil && data2.AttestationData != nil { - assert.Empty(t, data1.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - assert.Empty(t, data2.AttestationData.SDKResponse.Evidences, "Until Attestation is available in public release, this should be empty") - t.Log("Both collections successfully have attestation data") - } else { - t.Log("Collections do not have attestation data - expected in test environment without real dependencies") - } -} - -func testNilAttestationManagerSkipsCollection(t *testing.T) { - ctx := context.Background() - - // Create collector with nil attestation manager - testCollector := createTestCollectorWithNilAttestation() - - // Collection should skip 
gracefully - data, err := testCollector.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Nil(t, data.AttestationData, "Should not collect attestation data when manager is nil") -} - -func TestCollector_AttestationDataCollection_WithMockData(t *testing.T) { - // This test verifies collection behavior when attestation is unavailable - ctx := context.Background() - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - testCollector := createTestCollector(attestationManager) - collectorImpl := testCollector.(*collector) - - // Verify that collection works when no attestation data is available - data1, err := testCollector.Collect(ctx) - require.NoError(t, err) - require.NotNil(t, data1) - - // Should be nil since no attestation data is available - assert.Nil(t, data1.AttestationData, "Should be nil when no attestation data available") - assert.True(t, collectorImpl.lastAttestationCollection.IsZero(), "Should remain zero") - - t.Log("Successfully tested collection with no attestation data") -} - -func TestAttestationManager_UpdateTracking(t *testing.T) { - ctx := context.Background() - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - manager := attestation.NewManager(ctx, nil, attestationCfg) // nil nvmlInstance for testing - - // Initially, no updates - baseTime := time.Now().UTC() - assert.False(t, manager.IsAttestationDataUpdated(baseTime), - "Should return false before any attestation runs") - - // Start the manager and test the update tracking - manager.Start() - defer manager.Stop() - - // Give it time to attempt attestation - time.Sleep(100 * time.Millisecond) - - // In test environment this may still be false due to NVML/HTTP failures, but that's expected - updated := 
manager.IsAttestationDataUpdated(baseTime) - t.Logf("Attestation updated after start: %v", updated) - - // The important part is that the method doesn't crash and returns a boolean - assert.IsType(t, false, updated, "IsAttestationDataUpdated should return a boolean") -} - -// Helper functions - -func createTestCollector(attestationManager *attestation.Manager) Collector { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: false, - IncludeMetrics: false, - IncludeEvents: false, - IncludeComponentData: false, - } - - return New( - cfg, - nil, // fullConfig - nil, // allComponentNames - nil, // metricsStore - nil, // eventStore - nil, // componentsRegistry - nil, // nvmlInstance - attestationManager, - "test-machine-id", - nil, // dcgmGPUIndexes - ) -} - -func createTestCollectorWithNilAttestation() Collector { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: false, - IncludeMetrics: false, - IncludeEvents: false, - IncludeComponentData: false, - } - - return New( - cfg, - nil, // fullConfig - nil, // allComponentNames - nil, // metricsStore - nil, // eventStore - nil, // componentsRegistry - nil, // nvmlInstance - nil, // attestationManager (nil for testing) - "test-machine-id", - nil, // dcgmGPUIndexes - ) -} - func TestGenerateCollectionID(t *testing.T) { // Generate multiple collection IDs id1 := GenerateCollectionID() @@ -376,20 +75,13 @@ func TestGenerateEventID(t *testing.T) { } func TestNew(t *testing.T) { - ctx := context.Background() cfg := &config.HealthExporterConfig{ IncludeMachineInfo: true, IncludeMetrics: true, IncludeEvents: true, IncludeComponentData: true, } - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - - c := New(cfg, nil, nil, nil, nil, nil, nil, attestationManager, "test-machine-id", nil) + c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil) 
assert.NotNil(t, c, "Collector should be created") @@ -404,7 +96,6 @@ func TestCollector_Collect_BasicFlow(t *testing.T) { IncludeMetrics: false, IncludeEvents: false, IncludeComponentData: false, - Attestation: config.AttestationConfig{}, } collector := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil) @@ -424,7 +115,6 @@ func TestCollector_Collect_BasicFlow(t *testing.T) { assert.Empty(t, data.Metrics, "Metrics should be empty when disabled") assert.Empty(t, data.Events, "Events should be empty when disabled") assert.Empty(t, data.ComponentData, "ComponentData should be empty when disabled") - assert.Nil(t, data.AttestationData, "AttestationData should be nil when disabled") } func TestCollector_CollectMachineInfo_NoNVML(t *testing.T) { @@ -444,180 +134,6 @@ func TestCollector_CollectMachineInfo_NoNVML(t *testing.T) { assert.Nil(t, data.MachineInfo, "MachineInfo should be nil without NVML") } -func TestCollector_CollectMachineInfo_UsesCachedValue(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - expected := &machineinfo.MachineInfo{Hostname: "cached-host"} - provider := &mockMachineInfoProvider{ - cached: expected, - } - c.machineInfoProvider = provider - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - require.NotNil(t, data.MachineInfo) - assert.Equal(t, expected, data.MachineInfo) - assert.Equal(t, int32(1), provider.refreshCalls.Load()) -} - -func TestCollector_CollectMachineInfo_WaitsBrieflyForInitialRefresh(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - c.machineInfoProvider = provider - - go func() { - time.Sleep(50 * time.Millisecond) - 
provider.setCached(&machineinfo.MachineInfo{Hostname: "prewarmed-host"}) - provider.markInitialRefreshDone() - }() - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - require.NotNil(t, data.MachineInfo) - assert.Equal(t, "prewarmed-host", data.MachineInfo.Hostname) -} - -func TestCollector_CollectMachineInfo_RefreshDoesNotBlockMetrics(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - IncludeMetrics: true, - MetricsLookback: metav1.Duration{Duration: 5 * time.Minute}, - } - - c := New(cfg, nil, nil, &mockMetricsStore{ - metrics: pkgmetrics.Metrics{ - {Component: "gpu", Name: "temperature", Value: 70, UnixMilliseconds: time.Now().UnixMilli()}, - }, - }, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - - blocker := make(chan struct{}) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) { - provider.markInitialRefreshDone() - <-blocker - } - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - close(blocker) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Len(t, data.Metrics, 1) - assert.Nil(t, data.MachineInfo) - assert.GreaterOrEqual(t, elapsed, 4900*time.Millisecond) - assert.Less(t, elapsed, 5500*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitDoesNotRepeatAfterFirstRefresh(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.markInitialRefreshDone() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, elapsed, 
200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitDoesNotRepeatAfterTimeout(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - start := time.Now() - data, err := c.Collect(ctx) - firstElapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.GreaterOrEqual(t, firstElapsed, 4900*time.Millisecond) - assert.Less(t, firstElapsed, 5500*time.Millisecond) - - start = time.Now() - data, err = c.Collect(ctx) - secondElapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, secondElapsed, 200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_InitialWaitHonorsContextCancellation(t *testing.T) { - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - provider := newMockMachineInfoProvider() - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - start := time.Now() - data, err := c.Collect(ctx) - elapsed := time.Since(start) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Less(t, elapsed, 200*time.Millisecond) -} - -func TestCollector_CollectMachineInfo_RetainsLastGoodOnRefreshFailure(t *testing.T) { - ctx := context.Background() - cfg := &config.HealthExporterConfig{ - IncludeMachineInfo: true, - } - - c := New(cfg, nil, nil, nil, nil, nil, nil, nil, "test-machine-id", nil).(*collector) - expected := &machineinfo.MachineInfo{Hostname: "last-good"} - provider := newMockMachineInfoProvider() - provider.cached = expected - 
provider.initialRefreshOnce.Do(func() { close(provider.initialRefreshDone) }) - provider.refreshFn = func(parent context.Context) {} - c.machineInfoProvider = provider - - data, err := c.Collect(ctx) - - require.NoError(t, err) - require.NotNil(t, data) - assert.Equal(t, expected, data.MachineInfo) -} - func TestCachedMachineInfoProvider_DeduplicatesConcurrentRefresh(t *testing.T) { originalGetMachineInfo := getMachineInfo defer func() { @@ -975,7 +491,6 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { IncludeMetrics: true, IncludeEvents: true, IncludeComponentData: true, - Attestation: config.AttestationConfig{}, MetricsLookback: metav1.Duration{Duration: 5 * time.Minute}, EventsLookback: metav1.Duration{Duration: 5 * time.Minute}, } @@ -1005,15 +520,9 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { components: []components.Component{mockComp}, } - attestationCfg := &config.AttestationConfig{ - Interval: metav1.Duration{Duration: 20 * time.Second}, - JitterEnabled: false, - } - attestationManager := attestation.NewManager(ctx, nil, attestationCfg) - mockEventStore := &mockEventStore{} - collector := New(cfg, nil, nil, mockMetricsStore, mockEventStore, mockRegistry, nil, attestationManager, "test-machine-id", nil) + collector := New(cfg, nil, nil, mockMetricsStore, mockEventStore, mockRegistry, nil, nil, "test-machine-id", nil) data, err := collector.Collect(ctx) require.NoError(t, err) @@ -1028,7 +537,6 @@ func TestCollector_AllFeaturesEnabled(t *testing.T) { assert.Len(t, data.Events, 1) assert.Len(t, data.ComponentData, 1) // MachineInfo will be nil without NVML - // AttestationData may be nil in test environment } // ============================================================================= @@ -1040,83 +548,6 @@ type mockMetricsStore struct { shouldError bool } -type mockMachineInfoProvider struct { - mu sync.RWMutex - cached *machineinfo.MachineInfo - refreshFn func(context.Context) - refreshCalls atomic.Int32 - initialWaited bool - 
initialRefreshDone chan struct{} - initialRefreshOnce sync.Once -} - -func newMockMachineInfoProvider() *mockMachineInfoProvider { - return &mockMachineInfoProvider{ - initialRefreshDone: make(chan struct{}), - } -} - -func (m *mockMachineInfoProvider) Get() (*machineinfo.MachineInfo, bool) { - m.mu.RLock() - defer m.mu.RUnlock() - - if m.cached == nil { - return nil, false - } - return m.cached, true -} - -func (m *mockMachineInfoProvider) RefreshAsync(parent context.Context) { - m.refreshCalls.Add(1) - if m.refreshFn == nil { - return - } - go func() { - defer func() { - _ = recover() - }() - m.refreshFn(parent) - }() -} - -func (m *mockMachineInfoProvider) WaitForInitialRefresh(ctx context.Context, maxWait time.Duration) bool { - if maxWait <= 0 { - return false - } - - m.mu.Lock() - if m.initialWaited { - m.mu.Unlock() - return false - } - m.initialWaited = true - m.mu.Unlock() - - timer := time.NewTimer(maxWait) - defer timer.Stop() - - select { - case <-m.initialRefreshDone: - return true - case <-ctx.Done(): - return false - case <-timer.C: - return false - } -} - -func (m *mockMachineInfoProvider) setCached(info *machineinfo.MachineInfo) { - m.mu.Lock() - defer m.mu.Unlock() - m.cached = info -} - -func (m *mockMachineInfoProvider) markInitialRefreshDone() { - m.initialRefreshOnce.Do(func() { - close(m.initialRefreshDone) - }) -} - func (m *mockMetricsStore) Read(ctx context.Context, opts ...pkgmetrics.OpOption) (pkgmetrics.Metrics, error) { if m.shouldError { return nil, errors.New("mock metrics store error") diff --git a/internal/exporter/converter/csv.go b/internal/exporter/converter/csv.go index 8d12a650..d8c913ab 100644 --- a/internal/exporter/converter/csv.go +++ b/internal/exporter/converter/csv.go @@ -282,7 +282,7 @@ func (c *csvConverter) writeMachineInfoCSV(outputDir, filename string, data *col // Header and basic system info records = append(records, []string{"attribute_name", "attribute_value"}, - []string{"Fleetint Version", i.FleetintVersion}, 
+ []string{"Agent Version", i.AgentVersion}, []string{"Container Runtime Version", i.ContainerRuntimeVersion}, []string{"OS Image", i.OSImage}, []string{"Kernel Version", i.KernelVersion}, diff --git a/internal/exporter/converter/csv_test.go b/internal/exporter/converter/csv_test.go index 6a5905c8..1881c364 100644 --- a/internal/exporter/converter/csv_test.go +++ b/internal/exporter/converter/csv_test.go @@ -177,10 +177,10 @@ func TestCSVConverter_Convert_MachineInfo(t *testing.T) { timestamp := "20251105_120000" machineInfo := &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - DCGMVersion: "4.2.3", - OSImage: "Ubuntu 22.04", - KernelVersion: "5.15.0", + AgentVersion: "0.1.5", + DCGMVersion: "4.2.3", + OSImage: "Ubuntu 22.04", + KernelVersion: "5.15.0", CPUInfo: &apiv1.MachineCPUInfo{ Type: "Intel", Manufacturer: "Intel", @@ -279,7 +279,7 @@ func TestCSVConverter_Convert_AllData(t *testing.T) { }, }, MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", + AgentVersion: "0.1.5", }, } diff --git a/internal/exporter/converter/otlp.go b/internal/exporter/converter/otlp.go index 89bd9be8..ac33cc88 100644 --- a/internal/exporter/converter/otlp.go +++ b/internal/exporter/converter/otlp.go @@ -19,6 +19,7 @@ package converter import ( "encoding/json" "fmt" + "os" "reflect" "strings" "time" @@ -32,6 +33,8 @@ import ( "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" ) +var osHostname = os.Hostname + // OTLPData holds both metrics and logs for OTLP export type OTLPData struct { Metrics *metricsv1.MetricsData @@ -98,7 +101,7 @@ func (c *otlpConverter) Convert(data *collector.HealthData) *OTLPData { } } -// createOTLPResource creates OTLP resource with machine info, agent config, and identification +// createOTLPResource creates a minimal OTLP resource for telemetry correlation. 
func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resourcev1.Resource { attributes := []*commonv1.KeyValue{ { @@ -113,39 +116,31 @@ func (c *otlpConverter) createOTLPResource(data *collector.HealthData) *resource Value: &commonv1.AnyValue_StringValue{StringValue: data.MachineID}, }, }, - { - Key: "agentConfig.totalComponents", - Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_IntValue{IntValue: int64(len(data.ComponentData))}, - }, - }, } - // Add agent config entries as resource attributes - for _, entry := range data.ConfigEntries { + if hostname := resolveOTLPHostname(); hostname != "" { attributes = append(attributes, &commonv1.KeyValue{ - Key: "agentConfig." + entry.Key, + Key: "host.name", Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_StringValue{StringValue: entry.Value}, + Value: &commonv1.AnyValue_StringValue{StringValue: hostname}, }, }) } - // Add machine info attributes if available using reflection - if data.MachineInfo != nil { - machineInfoAttributes := convertStructToOTLPAttributes(data.MachineInfo) - attributes = append(attributes, machineInfoAttributes...) + return &resourcev1.Resource{ + Attributes: attributes, } +} - // Add attestation data attributes if available using reflection - if data.AttestationData != nil { - attestationAttributes := convertStructToOTLPAttributesWithPrefix(data.AttestationData, "attestation") - attributes = append(attributes, attestationAttributes...) 
+func resolveOTLPHostname() string { + if hostname := strings.TrimSpace(os.Getenv("HOSTNAME")); hostname != "" { + return hostname } - - return &resourcev1.Resource{ - Attributes: attributes, + hostname, err := osHostname() + if err != nil { + return "" } + return strings.TrimSpace(hostname) } // convertMetricsToOTLP converts health metrics to OTLP metrics format @@ -211,12 +206,6 @@ func (c *otlpConverter) convertMetricsToOTLP(data *collector.HealthData) []*metr Value: &commonv1.AnyValue_IntValue{IntValue: int64(len(data.ComponentData))}, }, }, - { - Key: "attestation_evidences_count", - Value: &commonv1.AnyValue{ - Value: &commonv1.AnyValue_IntValue{IntValue: int64(c.getAttestationEvidencesCount(data))}, - }, - }, }, }, }, @@ -259,16 +248,17 @@ func (c *otlpConverter) convertLabelsToOTLPAttributes(labels map[string]string, return attributes } -// buildGPUUUIDToIndexMap builds a UUID → GPU index lookup from MachineInfo. +// buildGPUUUIDToIndexMap builds a UUID → GPU index lookup from the collector snapshot. 
func buildGPUUUIDToIndexMap(data *collector.HealthData) map[string]string { m := make(map[string]string) - if data.MachineInfo == nil || data.MachineInfo.GPUInfo == nil { + if len(data.GPUUUIDToIndex) == 0 { return m } - for _, gpu := range data.MachineInfo.GPUInfo.GPUs { - if gpu.UUID != "" && gpu.GPUIndex != "" { - m[gpu.UUID] = gpu.GPUIndex + for uuid, gpuIndex := range data.GPUUUIDToIndex { + if uuid == "" || gpuIndex == "" { + continue } + m[uuid] = gpuIndex } return m } @@ -646,11 +636,3 @@ func convertStructToOTLPAttributesWithPrefix(v interface{}, prefix string) []*co return attributes } - -// getAttestationEvidencesCount returns the count of attestation evidences -func (c *otlpConverter) getAttestationEvidencesCount(data *collector.HealthData) int { - if data.AttestationData == nil { - return 0 - } - return len(data.AttestationData.SDKResponse.Evidences) -} diff --git a/internal/exporter/converter/otlp_test.go b/internal/exporter/converter/otlp_test.go index ee475cee..529e037e 100644 --- a/internal/exporter/converter/otlp_test.go +++ b/internal/exporter/converter/otlp_test.go @@ -16,6 +16,8 @@ package converter import ( + "errors" + "os" "testing" "time" @@ -28,7 +30,6 @@ import ( metricsv1 "go.opentelemetry.io/proto/otlp/metrics/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" ) @@ -309,21 +310,12 @@ func TestOTLPConverter_Convert_WithComponentData(t *testing.T) { assert.True(t, found, "Should find component data log") } -func TestOTLPConverter_Convert_WithMachineInfo(t *testing.T) { +func TestOTLPConverter_Convert_IgnoresMachineInfoInResource(t *testing.T) { data := &collector.HealthData{ Timestamp: time.Now(), MachineID: "test-machine", MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - DCGMVersion: "4.2.3", - OSImage: "Ubuntu 
22.04", - KernelVersion: "5.15.0", - CPUInfo: &apiv1.MachineCPUInfo{ - Type: "Intel", - Manufacturer: "Intel", - Architecture: "x86_64", - LogicalCores: 8, - }, + DCGMVersion: "4.2.3", }, } @@ -333,82 +325,13 @@ func TestOTLPConverter_Convert_WithMachineInfo(t *testing.T) { require.NotNil(t, otlpData) require.NotNil(t, otlpData.Metrics) - // Check resource has machine info attributes rm := otlpData.Metrics.ResourceMetrics[0] assert.NotNil(t, rm.Resource) assert.Greater(t, len(rm.Resource.Attributes), 0) - // Verify machine info is embedded in resource attributes - // The attributes may have different keys based on how machine info is flattened - attrCount := len(rm.Resource.Attributes) - assert.Greater(t, attrCount, 2, "Should have multiple resource attributes including machine info") - - // Check that some attributes exist (the exact key names may vary) - attrKeys := make([]string, 0, len(rm.Resource.Attributes)) for _, attr := range rm.Resource.Attributes { - attrKeys = append(attrKeys, attr.Key) - } - // Should have at least service.name and machine.id - hasServiceName := false - for _, key := range attrKeys { - if key == "service.name" { - hasServiceName = true - break - } + assert.NotEqual(t, "dcgmVersion", attr.Key) } - assert.True(t, hasServiceName, "Should have service.name attribute") - assert.Equal(t, "4.2.3", findAttribute(t, rm.Resource.Attributes, "dcgmVersion").GetStringValue()) -} - -func TestOTLPConverter_Convert_WithAttestationData(t *testing.T) { - attestationData := &attestation.AttestationData{ - Success: true, - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - { - Arch: "BLACKWELL", - Certificate: "test-cert", - DriverVersion: "575.28", - Evidence: "test-evidence", - Nonce: "test-nonce", - VBIOSVersion: "96.00.AF.00.01", - Version: "1.0", - }, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Date(2025, 11, 5, 12, 0, 0, 0, time.UTC), - } - - data := &collector.HealthData{ 
- Timestamp: time.Now(), - MachineID: "test-machine", - AttestationData: attestationData, - } - - converter := NewOTLPConverter() - otlpData := converter.Convert(data) - - require.NotNil(t, otlpData) - require.NotNil(t, otlpData.Logs) - - // Should NOT have attestation logs - rl := otlpData.Logs.ResourceLogs[0] - logs := rl.ScopeLogs[0].LogRecords - assert.Empty(t, logs, "Should not have attestation logs") - - // Should have attestation data in resource attributes - rm := otlpData.Metrics.ResourceMetrics[0] - attrs := rm.Resource.Attributes - foundAttestation := false - for _, attr := range attrs { - if contains(attr.Key, "attestation") { - foundAttestation = true - break - } - } - assert.True(t, foundAttestation, "Should have attestation data in resource attributes") } func TestOTLPConverter_ConvertStructToOTLPAttributes(t *testing.T) { @@ -699,13 +622,9 @@ func TestOTLPConverter_ConvertLabelsToOTLPAttributes_EnrichesGPUIndex(t *testing func TestBuildGPUUUIDToIndexMap(t *testing.T) { t.Run("builds map from machine info", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{ - GPUInfo: &apiv1.MachineGPUInfo{ - GPUs: []apiv1.MachineGPUInstance{ - {UUID: "GPU-abc-123", GPUIndex: "0"}, - {UUID: "GPU-def-456", GPUIndex: "1"}, - }, - }, + GPUUUIDToIndex: map[string]string{ + "GPU-abc-123": "0", + "GPU-def-456": "1", }, } @@ -721,9 +640,9 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { assert.Empty(t, m) }) - t.Run("returns empty map when GPU info is nil", func(t *testing.T) { + t.Run("returns empty map when mapping is nil", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{}, + GPUUUIDToIndex: nil, } m := buildGPUUUIDToIndexMap(data) assert.Empty(t, m) @@ -731,14 +650,10 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { t.Run("skips entries with empty uuid or index", func(t *testing.T) { data := &collector.HealthData{ - MachineInfo: &machineinfo.MachineInfo{ - GPUInfo: 
&apiv1.MachineGPUInfo{ - GPUs: []apiv1.MachineGPUInstance{ - {UUID: "GPU-abc-123", GPUIndex: "0"}, - {UUID: "", GPUIndex: "1"}, - {UUID: "GPU-ghi-789", GPUIndex: ""}, - }, - }, + GPUUUIDToIndex: map[string]string{ + "GPU-abc-123": "0", + "": "1", + "GPU-ghi-789": "", }, } @@ -748,55 +663,6 @@ func TestBuildGPUUUIDToIndexMap(t *testing.T) { }) } -func TestOTLPConverter_GetAttestationEvidencesCount(t *testing.T) { - tests := []struct { - name string - data *collector.HealthData - expectedCount int - }{ - { - name: "with_evidences", - data: &collector.HealthData{ - AttestationData: &attestation.AttestationData{ - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - {Arch: "BLACKWELL"}, - {Arch: "HOPPER"}, - }, - }, - }, - }, - expectedCount: 2, - }, - { - name: "nil_attestation", - data: &collector.HealthData{ - AttestationData: nil, - }, - expectedCount: 0, - }, - { - name: "empty_evidences", - data: &collector.HealthData{ - AttestationData: &attestation.AttestationData{ - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{}, - }, - }, - }, - expectedCount: 0, - }, - } - - converter := &otlpConverter{} - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - count := converter.getAttestationEvidencesCount(tt.data) - assert.Equal(t, tt.expectedCount, count) - }) - } -} - func TestOTLPConverter_SummaryMetric(t *testing.T) { data := &collector.HealthData{ Timestamp: time.Now(), @@ -881,6 +747,37 @@ func TestOTLPConverter_Interface(t *testing.T) { assert.NotNil(t, converter) } +func TestResolveOTLPHostname(t *testing.T) { + origHostEnv, hadHostEnv := os.LookupEnv("HOSTNAME") + origOSHostname := osHostname + t.Cleanup(func() { + osHostname = origOSHostname + if hadHostEnv { + _ = os.Setenv("HOSTNAME", origHostEnv) + } else { + _ = os.Unsetenv("HOSTNAME") + } + }) + + t.Run("falls back to hostname env", func(t *testing.T) { + _ = os.Setenv("HOSTNAME", "pod-host-a") + osHostname = 
func() (string, error) { return "os-host-a", nil } + assert.Equal(t, "pod-host-a", resolveOTLPHostname()) + }) + + t.Run("falls back to os hostname", func(t *testing.T) { + _ = os.Unsetenv("HOSTNAME") + osHostname = func() (string, error) { return "os-host-a", nil } + assert.Equal(t, "os-host-a", resolveOTLPHostname()) + }) + + t.Run("returns empty on hostname error", func(t *testing.T) { + _ = os.Unsetenv("HOSTNAME") + osHostname = func() (string, error) { return "", errors.New("boom") } + assert.Equal(t, "", resolveOTLPHostname()) + }) +} + func TestOTLPConverter_Convert_AllData(t *testing.T) { // Test with all data types combined data := &collector.HealthData{ @@ -898,20 +795,6 @@ func TestOTLPConverter_Convert_AllData(t *testing.T) { "reason": "All OK", }, }, - MachineInfo: &machineinfo.MachineInfo{ - FleetintVersion: "0.1.5", - }, - AttestationData: &attestation.AttestationData{ - Success: true, - SDKResponse: attestation.AttestationSDKResponse{ - Evidences: []attestation.EvidenceItem{ - {Arch: "BLACKWELL", VBIOSVersion: "96.00"}, - }, - ResultCode: 0, - ResultMessage: "Ok", - }, - NonceRefreshTimestamp: time.Now(), - }, } converter := NewOTLPConverter() @@ -930,7 +813,6 @@ func TestOTLPConverter_Convert_AllData(t *testing.T) { rl := otlpData.Logs.ResourceLogs[0] assert.Greater(t, len(rl.ScopeLogs[0].LogRecords), 0) - // Verify resource has attributes from machine info assert.Greater(t, len(rm.Resource.Attributes), 0) } diff --git a/internal/exporter/exporter.go b/internal/exporter/exporter.go index 5b7e9b97..dd8301ad 100644 --- a/internal/exporter/exporter.go +++ b/internal/exporter/exporter.go @@ -30,9 +30,8 @@ import ( "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" - "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/endpoint" - 
"github.com/NVIDIA/fleet-intelligence-agent/internal/enrollment" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/converter" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/writer" @@ -42,15 +41,16 @@ import ( // Ensure healthExporter implements the Exporter interface var _ Exporter = (*healthExporter)(nil) +var newBackendClient = backendclient.New + // healthExporter implements the Exporter interface with improved architecture type healthExporter struct { - ctx context.Context - cancel context.CancelFunc - options *exporterOptions - collector collector.Collector - fileWriter writer.FileWriter - httpWriter writer.HTTPWriter - attestationManager *attestation.Manager + ctx context.Context + cancel context.CancelFunc + options *exporterOptions + collector collector.Collector + fileWriter writer.FileWriter + httpWriter writer.HTTPWriter // Last export timestamp for tracking lastExport time.Time @@ -76,10 +76,6 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { } options.setDefaults() - // Create attestation manager (always enabled) - attestationManager := attestation.NewManager(cctx, options.nvmlInstance, &options.config.Attestation) - log.Logger.Infow("Attestation manager created", "interval", options.config.Attestation.Interval.Duration, "jitter_enabled", options.config.Attestation.JitterEnabled) - // Get all component names for config export allComponentNames := registry.AllComponentNames() @@ -91,7 +87,7 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { options.eventStore, options.componentsRegistry, options.nvmlInstance, - attestationManager, + nil, options.machineID, options.dcgmGPUIndexes, ) @@ -103,13 +99,12 @@ func New(ctx context.Context, opts ...ExporterOption) (Exporter, error) { httpWriter := writer.NewHTTPWriter(options.httpClient, otlpConverter) exporter := &healthExporter{ - ctx: cctx, - cancel: cancel, 
- options: options, - collector: dataCollector, - fileWriter: fileWriter, - httpWriter: httpWriter, - attestationManager: attestationManager, + ctx: cctx, + cancel: cancel, + options: options, + collector: dataCollector, + fileWriter: fileWriter, + httpWriter: httpWriter, } // Set JWT refresh function on the HTTP writer @@ -127,11 +122,6 @@ func (e *healthExporter) Start() error { log.Logger.Infow("Starting health exporter") - // Start the attestation manager if enabled - if e.attestationManager != nil { - e.attestationManager.Start() - } - // Start the health export ticker go func() { ticker := time.NewTicker(e.options.config.Interval.Duration) @@ -158,9 +148,6 @@ func (e *healthExporter) Start() error { // Stop gracefully shuts down the exporter func (e *healthExporter) Stop() error { log.Logger.Infow("Stopping health exporter") - if e.attestationManager != nil { - e.attestationManager.Stop() - } e.cancel() return nil } @@ -267,17 +254,39 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { return } - // Load metrics endpoint (update even if empty to handle un-enrollment) - if metricsEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "metrics_endpoint"); err == nil { - if metricsEndpoint != "" { - validated, validateErr := endpoint.ValidateBackendEndpoint(metricsEndpoint) - if validateErr != nil { - log.Logger.Errorw("ignoring invalid metrics endpoint from metadata", "error", validateErr) - metricsEndpoint = "" + metricsEndpoint := "" + logsEndpoint := "" + baseURL, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "backend_base_url") + if err != nil { + log.Logger.Errorw("failed to read backend base URL from metadata", "error", err) + baseURL = "" + } + if baseURL != "" { + validated, validateErr := endpoint.ValidateBackendEndpoint(baseURL) + if validateErr != nil { + log.Logger.Errorw("ignoring invalid backend base URL from metadata", "error", validateErr) + metricsEndpoint = e.readValidatedEndpoint(ctx, 
"metrics_endpoint") + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") + } else { + if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "metrics"); joinErr == nil { + metricsEndpoint = joined } else { - metricsEndpoint = validated.String() + log.Logger.Errorw("failed to derive metrics endpoint from backend base URL", "error", joinErr) + metricsEndpoint = e.readValidatedEndpoint(ctx, "metrics_endpoint") + } + if joined, joinErr := endpoint.JoinPath(validated, "api", "v1", "health", "logs"); joinErr == nil { + logsEndpoint = joined + } else { + log.Logger.Errorw("failed to derive logs endpoint from backend base URL", "error", joinErr) + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") } } + } else { + metricsEndpoint = e.readValidatedEndpoint(ctx, "metrics_endpoint") + logsEndpoint = e.readValidatedEndpoint(ctx, "logs_endpoint") + } + + { if e.options.config.MetricsEndpoint != metricsEndpoint { e.options.config.MetricsEndpoint = metricsEndpoint if metricsEndpoint == "" { @@ -286,21 +295,9 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx context.Context) { log.Logger.Infow("updated metrics endpoint from metadata", "metrics_endpoint", metricsEndpoint) } } - } else { - log.Logger.Errorw("failed to read metrics endpoint from metadata", "error", err) } - // Load logs endpoint (update even if empty to handle un-enrollment) - if logsEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "logs_endpoint"); err == nil { - if logsEndpoint != "" { - validated, validateErr := endpoint.ValidateBackendEndpoint(logsEndpoint) - if validateErr != nil { - log.Logger.Errorw("ignoring invalid logs endpoint from metadata", "error", validateErr) - logsEndpoint = "" - } else { - logsEndpoint = validated.String() - } - } + { if e.options.config.LogsEndpoint != logsEndpoint { e.options.config.LogsEndpoint = logsEndpoint if logsEndpoint == "" { @@ -309,8 +306,6 @@ func (e *healthExporter) refreshConfigFromMetadata(ctx 
context.Context) { log.Logger.Infow("updated logs endpoint from metadata", "logs_endpoint", logsEndpoint) } } - } else { - log.Logger.Errorw("failed to read logs endpoint from metadata", "error", err) } // Load auth token (update even if empty to handle un-enrollment) @@ -355,14 +350,28 @@ func (e *healthExporter) refreshJWTToken(ctx context.Context) (string, error) { return "", fmt.Errorf("no SAK token available for JWT refresh") } - // Read enroll endpoint from metadata (stored during enrollment) - enrollEndpoint, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "enroll_endpoint") - if err != nil || enrollEndpoint == "" { - return "", fmt.Errorf("no enroll endpoint available for JWT refresh") + baseURL, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, "backend_base_url") + if err == nil && baseURL != "" { + if _, validateErr := endpoint.ValidateBackendEndpoint(baseURL); validateErr != nil { + log.Logger.Errorw("ignoring invalid backend base URL for JWT refresh", "backend_base_url", baseURL, "error", validateErr) + baseURL = "" + } + } + if err != nil || baseURL == "" { + baseURL, err = e.readLegacyBackendBaseURL(ctx) + if err != nil { + return "", err + } + if baseURL == "" { + return "", fmt.Errorf("no backend base URL available for JWT refresh") + } } - // Perform enrollment to get new JWT token using the shared function - newJWT, err := enrollment.PerformEnrollment(ctx, enrollEndpoint, sakToken) + client, err := newBackendClient(baseURL) + if err != nil { + return "", fmt.Errorf("failed to create backend client for JWT refresh: %w", err) + } + newJWT, err := client.Enroll(ctx, sakToken) if err != nil { return "", fmt.Errorf("failed to refresh JWT token: %w", err) } @@ -378,3 +387,32 @@ func (e *healthExporter) refreshJWTToken(ctx context.Context) (string, error) { log.Logger.Infow("Successfully refreshed JWT token") return newJWT, nil } + +func (e *healthExporter) readValidatedEndpoint(ctx context.Context, key string) string { + value, err := 
pkgmetadata.ReadMetadata(ctx, e.options.dbRO, key) + if err != nil || value == "" { + return "" + } + validated, err := endpoint.ValidateBackendEndpoint(value) + if err != nil { + log.Logger.Errorw("ignoring invalid legacy endpoint from metadata", "key", key, "error", err) + return "" + } + return validated.String() +} + +func (e *healthExporter) readLegacyBackendBaseURL(ctx context.Context) (string, error) { + for _, key := range []string{"enroll_endpoint", "metrics_endpoint", "logs_endpoint", "nonce_endpoint"} { + value, err := pkgmetadata.ReadMetadata(ctx, e.options.dbRO, key) + if err != nil || value == "" { + continue + } + baseURL, err := endpoint.DeriveBackendBaseURL(value) + if err != nil { + log.Logger.Errorw("ignoring invalid legacy backend endpoint for JWT refresh", "key", key, "value", value, "error", err) + continue + } + return baseURL, nil + } + return "", nil +} diff --git a/internal/exporter/exporter_test.go b/internal/exporter/exporter_test.go index a30f3bb2..2a60edba 100644 --- a/internal/exporter/exporter_test.go +++ b/internal/exporter/exporter_test.go @@ -38,6 +38,7 @@ import ( pkgmetadata "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metadata" pkgmetrics "github.com/NVIDIA/fleet-intelligence-sdk/pkg/metrics" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/collector" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter/writer" @@ -706,10 +707,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - // Insert test data into metadata table using SetMetadata - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://new-metrics.example.com") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://new-logs.example.com") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "https://backend.example.com") 
require.NoError(t, err) err = pkgmetadata.SetMetadata(ctx, tmpDB, pkgmetadata.MetadataKeyToken, "new-test-token") require.NoError(t, err) @@ -736,8 +734,8 @@ func TestRefreshConfigFromMetadata(t *testing.T) { he.refreshConfigFromMetadata(ctx) // Verify config was updated - assert.Equal(t, "https://new-metrics.example.com", he.options.config.MetricsEndpoint) - assert.Equal(t, "https://new-logs.example.com", he.options.config.LogsEndpoint) + assert.Equal(t, "https://backend.example.com/api/v1/health/metrics", he.options.config.MetricsEndpoint) + assert.Equal(t, "https://backend.example.com/api/v1/health/logs", he.options.config.LogsEndpoint) assert.Equal(t, "new-test-token", he.options.config.AuthToken) // Cleanup @@ -751,10 +749,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - // Insert empty values using SetMetadata - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -792,9 +787,7 @@ func TestRefreshConfigFromMetadata(t *testing.T) { ctx := context.Background() - err := pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "http://bad-metrics.example.com") - require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://user@bad-logs.example.com") + err := pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "http://bad-backend.example.com") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -821,6 +814,41 @@ func TestRefreshConfigFromMetadata(t *testing.T) { err = exporter.Stop() require.NoError(t, err) }) + + t.Run("falls back to legacy endpoints when backend base URL is invalid", func(t *testing.T) { + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + + err := pkgmetadata.SetMetadata(ctx, tmpDB, 
"backend_base_url", "http://bad-backend.example.com") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://legacy.example.com/api/v1/health/metrics") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "logs_endpoint", "https://legacy.example.com/api/v1/health/logs") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + he.refreshConfigFromMetadata(ctx) + + assert.Equal(t, "https://legacy.example.com/api/v1/health/metrics", he.options.config.MetricsEndpoint) + assert.Equal(t, "https://legacy.example.com/api/v1/health/logs", he.options.config.LogsEndpoint) + + err = exporter.Stop() + require.NoError(t, err) + }) } // TestUpdateTokenInMetadata tests the updateTokenInMetadata function @@ -932,19 +960,16 @@ func TestRefreshJWTToken(t *testing.T) { }) t.Run("refreshes token successfully", func(t *testing.T) { - // Mock enrollment server expectedToken := "new-jwt-token" - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Verify request - assert.Equal(t, http.MethodPost, r.Method) - assert.Equal(t, "Bearer test-sak-token", r.Header.Get("Authorization")) - - // Return new token - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"jwt_assertion": "%s"}`, expectedToken) - })) - defer server.Close() + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + 
token: expectedToken, + }, nil + } tmpDB := setupTestDB(t) defer tmpDB.Close() @@ -953,7 +978,7 @@ func TestRefreshJWTToken(t *testing.T) { // Setup metadata err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") require.NoError(t, err) - err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", server.URL) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "https://backend.example.com") require.NoError(t, err) cfg := &config.HealthExporterConfig{ @@ -984,6 +1009,168 @@ func TestRefreshJWTToken(t *testing.T) { err = exporter.Stop() require.NoError(t, err) }) + + t.Run("refreshes token successfully using legacy enroll endpoint", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com/api/v1/health/enroll") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + require.NotNil(t, exporter) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) + + t.Run("refreshes token using legacy endpoint when backend 
base URL is invalid", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", "test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "backend_base_url", "http://bad-backend.example.com") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com/api/v1/health/enroll") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) + + t.Run("refreshes token using later legacy endpoint when earlier one is malformed", func(t *testing.T) { + expectedToken := "new-jwt-token" + originalFactory := newBackendClient + t.Cleanup(func() { newBackendClient = originalFactory }) + newBackendClient = func(rawBaseURL string) (backendclient.Client, error) { + assert.Equal(t, "https://backend.example.com", rawBaseURL) + return &fakeJWTRefreshClient{ + expectedSAK: "test-sak-token", + token: expectedToken, + }, nil + } + + tmpDB := setupTestDB(t) + defer tmpDB.Close() + + ctx := context.Background() + err := pkgmetadata.SetMetadata(ctx, tmpDB, "sak_token", 
"test-sak-token") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "enroll_endpoint", "https://backend.example.com?bad=query") + require.NoError(t, err) + err = pkgmetadata.SetMetadata(ctx, tmpDB, "metrics_endpoint", "https://backend.example.com/api/v1/health/metrics") + require.NoError(t, err) + + cfg := &config.HealthExporterConfig{ + Interval: metav1.Duration{Duration: 1 * time.Minute}, + Timeout: metav1.Duration{Duration: 30 * time.Second}, + } + + exporter, err := New(ctx, + WithConfig(cfg), + WithDatabaseConnections(tmpDB, tmpDB), + WithMachineID("test-machine-id"), + ) + require.NoError(t, err) + + he := exporter.(*healthExporter) + + token, err := he.refreshJWTToken(ctx) + require.NoError(t, err) + assert.Equal(t, expectedToken, token) + + err = exporter.Stop() + require.NoError(t, err) + }) +} + +type fakeJWTRefreshClient struct { + expectedSAK string + token string +} + +func (f *fakeJWTRefreshClient) Enroll(_ context.Context, sakToken string) (string, error) { + if f.expectedSAK != "" && sakToken != f.expectedSAK { + return "", fmt.Errorf("unexpected sak token %q", sakToken) + } + return f.token, nil +} + +func (f *fakeJWTRefreshClient) UpsertNode(context.Context, string, *backendclient.NodeUpsertRequest, string) error { + return nil +} + +func (f *fakeJWTRefreshClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} + +func (f *fakeJWTRefreshClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil +} + +func (f *fakeJWTRefreshClient) RefreshToken(context.Context, string) (string, error) { + return "", nil } // TestIntegration provides integration tests diff --git a/internal/exporter/options.go b/internal/exporter/options.go index f89e8d95..dfc22438 100644 --- a/internal/exporter/options.go +++ b/internal/exporter/options.go @@ -180,10 +180,6 @@ func (c *exporterOptions) validate() error { return errors.New("components 
registry is required when IncludeComponentData is enabled") } - if c.config.IncludeMachineInfo && c.nvmlInstance == nil { - return errors.New("NVML instance is required when IncludeMachineInfo is enabled") - } - // Machine ID is always required - it should be set by server via WithMachineID if c.machineID == "" { return errors.New("machine ID is required - must be set via WithMachineID()") diff --git a/internal/exporter/options_test.go b/internal/exporter/options_test.go index e8d0cad8..3f614770 100644 --- a/internal/exporter/options_test.go +++ b/internal/exporter/options_test.go @@ -307,18 +307,6 @@ func TestExporterOptionsValidate(t *testing.T) { wantErr: true, expectedErr: "components registry is required when IncludeComponentData is enabled", }, - { - name: "machine info enabled but no NVML instance", - setupOpts: func() *exporterOptions { - return &exporterOptions{ - config: &config.HealthExporterConfig{ - IncludeMachineInfo: true, - }, - } - }, - wantErr: true, - expectedErr: "NVML instance is required when IncludeMachineInfo is enabled", - }, } for _, tt := range tests { diff --git a/internal/exporter/writer/http.go b/internal/exporter/writer/http.go index dff82093..c5aa7730 100644 --- a/internal/exporter/writer/http.go +++ b/internal/exporter/writer/http.go @@ -228,6 +228,7 @@ func (w *httpWriter) sendOTLPRequest(ctx context.Context, reqData []byte, dataTy req.Header.Set("X-Machine-ID", machineID) req.Header.Set("X-Data-Type", dataType) req.Header.Set("X-Collection-ID", collectionID) + req.Header.Set("X-Agent-Mode", "direct-inventory-write") if authToken != "" { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", authToken)) diff --git a/internal/inventory/hash.go b/internal/inventory/hash.go new file mode 100644 index 00000000..d0284811 --- /dev/null +++ b/internal/inventory/hash.go @@ -0,0 +1,41 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "time" +) + +// ComputeHash returns a deterministic hash for the stable inventory contents. +func ComputeHash(snap *Snapshot) (string, error) { + if snap == nil { + return "", fmt.Errorf("inventory snapshot is nil") + } + normalized := *snap + normalized.CollectedAt = time.Time{} + normalized.InventoryHash = "" + + payload, err := json.Marshal(normalized) + if err != nil { + return "", fmt.Errorf("marshal inventory snapshot: %w", err) + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]), nil +} diff --git a/internal/inventory/hash_test.go b/internal/inventory/hash_test.go new file mode 100644 index 00000000..49b5a9c9 --- /dev/null +++ b/internal/inventory/hash_test.go @@ -0,0 +1,49 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestComputeHashIgnoresCollectedAtAndExistingHash(t *testing.T) { + base := &Snapshot{ + CollectedAt: time.Unix(100, 0).UTC(), + InventoryHash: "old-hash", + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + } + other := *base + other.CollectedAt = time.Unix(200, 0).UTC() + other.InventoryHash = "different-old-hash" + + hash1, err := ComputeHash(base) + require.NoError(t, err) + hash2, err := ComputeHash(&other) + require.NoError(t, err) + require.Equal(t, hash1, hash2) + + other.Hostname = "host-b" + hash3, err := ComputeHash(&other) + require.NoError(t, err) + require.NotEqual(t, hash1, hash3) +} diff --git a/internal/inventory/manager.go b/internal/inventory/manager.go new file mode 100644 index 00000000..a720375b --- /dev/null +++ b/internal/inventory/manager.go @@ -0,0 +1,192 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inventory + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/NVIDIA/fleet-intelligence-sdk/pkg/log" +) + +// Manager coordinates periodic inventory collection into a store. +type Manager interface { + Run(ctx context.Context) error + CollectOnce(ctx context.Context) (*Snapshot, error) +} + +type manager struct { + mu sync.RWMutex + exportMu sync.Mutex + source Source + sink Sink + config InventoryConfig + + lastSnapshot *Snapshot + lastExportedHash string +} + +// NewManager creates an inventory manager. +func NewManager(source Source, sink Sink, cfg InventoryConfig) Manager { + return &manager{ + source: source, + sink: sink, + config: cfg, + } +} + +func (m *manager) Run(ctx context.Context) error { + _, err := m.CollectOnce(ctx) + if err != nil { + log.Logger.Warnw("initial inventory collection failed", "error", err) + } + if m.config.Interval <= 0 { + return nil + } + nextInterval := m.nextInterval(err) + if m.config.JitterEnabled { + nextInterval += calculateJitter(initialJitterCap(nextInterval)) + } + + for { + if err := sleepWithContext(ctx, nextInterval); err != nil { + return err + } + _, err = m.CollectOnce(ctx) + nextInterval = m.nextInterval(err) + } +} + +func (m *manager) nextInterval(err error) time.Duration { + nextInterval := m.config.Interval + if err != nil && m.config.RetryInterval > 0 && m.config.RetryInterval < nextInterval { + nextInterval = m.config.RetryInterval + if m.config.JitterEnabled { + nextInterval += 
calculateJitter(retryJitterCap(m.config.RetryInterval)) + } + } + return nextInterval +} + +func (m *manager) CollectOnce(ctx context.Context) (*Snapshot, error) { + if m.source == nil { + return nil, fmt.Errorf("inventory source is required") + } + snap, err := m.source.Collect(ctx) + if err != nil { + return nil, err + } + if snap == nil { + return nil, fmt.Errorf("inventory source returned nil snapshot") + } + hash, err := ComputeHash(snap) + if err != nil { + return nil, err + } + snap.InventoryHash = hash + + m.mu.Lock() + cloned := *snap + m.lastSnapshot = &cloned + m.mu.Unlock() + + if m.sink != nil { + m.exportMu.Lock() + defer m.exportMu.Unlock() + + m.mu.RLock() + alreadyExported := m.lastExportedHash == hash + m.mu.RUnlock() + if !alreadyExported { + if err := m.sink.Export(ctx, snap); err != nil { + if errors.Is(err, ErrNotReady) { + return snap, nil + } + return nil, err + } + m.mu.Lock() + m.lastExportedHash = hash + m.mu.Unlock() + } + } + + return snap, nil +} + +func sleepWithContext(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func calculateJitter(maxJitter time.Duration) time.Duration { + if maxJitter <= 0 { + return 0 + } + maxMs := int64(maxJitter / time.Millisecond) + if maxMs <= 0 { + return 0 + } + randomMs, err := rand.Int(rand.Reader, big.NewInt(maxMs)) + if err != nil { + log.Logger.Warnw("failed to generate secure inventory jitter, using fallback", "error", err) + return time.Duration(time.Now().UnixNano()%maxMs) * time.Millisecond + } + return time.Duration(randomMs.Int64()) * time.Millisecond +} + +func initialJitterCap(interval time.Duration) time.Duration { + if interval <= 0 { + return 0 + } + jitter := interval / 4 + const maxInitialJitter = 30 * time.Minute + if jitter > 
maxInitialJitter { + return maxInitialJitter + } + return jitter +} + +func retryJitterCap(retryInterval time.Duration) time.Duration { + if retryInterval <= 0 { + return 0 + } + jitter := retryInterval / 2 + const maxRetryJitter = 5 * time.Minute + if jitter > maxRetryJitter { + return maxRetryJitter + } + return jitter +} diff --git a/internal/inventory/manager_run_test.go b/internal/inventory/manager_run_test.go new file mode 100644 index 00000000..ca21b8c1 --- /dev/null +++ b/internal/inventory/manager_run_test.go @@ -0,0 +1,160 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
// errSource is a Source whose Collect always fails with err.
type errSource struct{ err error }

func (s errSource) Collect(context.Context) (*Snapshot, error) { return nil, s.err }

// nilSnapshotSource is a Source that returns neither a snapshot nor an error.
type nilSnapshotSource struct{}

func (nilSnapshotSource) Collect(context.Context) (*Snapshot, error) { return nil, nil }

// countingSource is a Source that signals every Collect call on collectCh
// (when set) before returning a fixed snapshot.
type countingSource struct {
	collectCh chan struct{}
}

func (s *countingSource) Collect(context.Context) (*Snapshot, error) {
	if s.collectCh != nil {
		// Blocks if the channel buffer is full and nobody is receiving.
		s.collectCh <- struct{}{}
	}
	return &Snapshot{MachineID: "machine-1", Hostname: "host-a"}, nil
}

// TestManagerCollectOnceErrors covers the three CollectOnce failure paths that
// do not involve a sink: missing source, failing source, and nil snapshot.
func TestManagerCollectOnceErrors(t *testing.T) {
	_, err := NewManager(nil, nil, InventoryConfig{}).CollectOnce(context.Background())
	require.ErrorContains(t, err, "inventory source is required")

	_, err = NewManager(errSource{err: errors.New("boom")}, nil, InventoryConfig{}).CollectOnce(context.Background())
	require.ErrorContains(t, err, "boom")

	_, err = NewManager(nilSnapshotSource{}, nil, InventoryConfig{}).CollectOnce(context.Background())
	require.ErrorContains(t, err, "nil snapshot")
}

// TestManagerRunWithZeroInterval checks that Run with a zero-value config
// collects and exports exactly once and then returns without error.
func TestManagerRunWithZeroInterval(t *testing.T) {
	src := &fakeSource{
		snapshots: []*Snapshot{{MachineID: "machine-1", Hostname: "host-a"}},
	}
	sink := &fakeSink{}

	err := NewManager(src, sink, InventoryConfig{}).Run(context.Background())
	require.NoError(t, err)
	require.Len(t, sink.exported, 1)
}

// TestManagerRunStopsOnContextCancel checks that a periodic Run returns
// context.Canceled once its context is cancelled after the first export.
func TestManagerRunStopsOnContextCancel(t *testing.T) {
	src := &fakeSource{
		snapshots: []*Snapshot{{MachineID: "machine-1", Hostname: "host-a"}},
	}
	sink := &fakeSink{ready: make(chan struct{}, 1)}
	ctx, cancel := context.WithCancel(context.Background())

	done := make(chan error, 1)
	go func() {
		done <- NewManager(src, sink, InventoryConfig{Interval: 10 * time.Millisecond}).Run(ctx)
	}()

	// Wait for the first export before cancelling so that the assertion on
	// sink.exported below is meaningful.
	select {
	case <-sink.ready:
	case <-time.After(250 * time.Millisecond):
		t.Fatal("timed out waiting for inventory export")
	}
	cancel()

	// Receiving from done also guarantees Run has finished touching the sink.
	err := <-done
	require.ErrorIs(t, err, context.Canceled)
	require.NotEmpty(t, sink.exported)
}

// TestSleepWithContext checks the zero-delay fast path and that an already
// cancelled context is reported for both zero and long delays.
func TestSleepWithContext(t *testing.T) {
	require.NoError(t, sleepWithContext(context.Background(), 0))

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	require.ErrorIs(t, sleepWithContext(ctx, 0), context.Canceled)

	ctx, cancel = context.WithCancel(context.Background())
	cancel()
	require.ErrorIs(t, sleepWithContext(ctx, time.Hour), context.Canceled)
}

// TestInventoryJitterHelpers pins the jitter caps (interval/4 up to 30m for the
// initial wait, retry/2 up to 5m for retries) and the [0, max) range of
// calculateJitter.
func TestInventoryJitterHelpers(t *testing.T) {
	require.Equal(t, time.Duration(0), initialJitterCap(0))
	require.Equal(t, 15*time.Second, initialJitterCap(time.Minute))
	require.Equal(t, 30*time.Minute, initialJitterCap(4*time.Hour))

	require.Equal(t, time.Duration(0), retryJitterCap(0))
	require.Equal(t, 30*time.Second, retryJitterCap(time.Minute))
	require.Equal(t, 5*time.Minute, retryJitterCap(20*time.Minute))

	require.Equal(t, time.Duration(0), calculateJitter(0))
	jitter := calculateJitter(50 * time.Millisecond)
	require.GreaterOrEqual(t, jitter, time.Duration(0))
	require.Less(t, jitter, 50*time.Millisecond)
}

// TestManagerRunUsesRetryIntervalWithoutJitter runs a manager whose source
// always fails, with a 1h interval and a 5ms retry interval, under a 25ms
// deadline. Run must end with DeadlineExceeded after roughly the deadline; the
// elapsed-time bounds are deliberately loose to tolerate scheduling noise.
func TestManagerRunUsesRetryIntervalWithoutJitter(t *testing.T) {
	src := errSource{err: errors.New("boom")}
	ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond)
	defer cancel()

	start := time.Now()
	err := NewManager(src, nil, InventoryConfig{
		Interval:      time.Hour,
		RetryInterval: 5 * time.Millisecond,
	}).Run(ctx)
	elapsed := time.Since(start)

	require.ErrorIs(t, err, context.DeadlineExceeded)
	require.GreaterOrEqual(t, elapsed, 15*time.Millisecond)
	require.Less(t, elapsed, 100*time.Millisecond)
}

// TestManagerRunWaitsIntervalBeforeSecondCollect checks that, with a 50ms
// interval, no second collection is observed within 20ms of the first one.
func TestManagerRunWaitsIntervalBeforeSecondCollect(t *testing.T) {
	src := &countingSource{collectCh: make(chan struct{}, 4)}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan error, 1)
	go func() {
		done <- NewManager(src, nil, InventoryConfig{Interval: 50 * time.Millisecond}).Run(ctx)
	}()

	// The initial collection happens immediately when Run starts.
	select {
	case <-src.collectCh:
	case <-time.After(100 * time.Millisecond):
		t.Fatal("timed out waiting for initial collection")
	}

	// Nothing else may arrive during a window shorter than the interval.
	select {
	case <-src.collectCh:
		t.Fatal("second collection happened before interval elapsed")
	case <-time.After(20 * time.Millisecond):
	}

	cancel()
	require.ErrorIs(t, <-done, context.Canceled)
}
+ +package inventory + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type fakeSource struct { + mu sync.Mutex + snapshots []*Snapshot + index int +} + +func (f *fakeSource) Collect(context.Context) (*Snapshot, error) { + f.mu.Lock() + defer f.mu.Unlock() + if len(f.snapshots) == 0 { + return nil, nil + } + if f.index >= len(f.snapshots) { + last := *f.snapshots[len(f.snapshots)-1] + return &last, nil + } + snap := *f.snapshots[f.index] + f.index++ + return &snap, nil +} + +type fakeSink struct { + mu sync.Mutex + exported []*Snapshot + ready chan struct{} +} + +func (f *fakeSink) Export(_ context.Context, snap *Snapshot) error { + f.mu.Lock() + defer f.mu.Unlock() + cloned := *snap + f.exported = append(f.exported, &cloned) + if f.ready != nil { + select { + case f.ready <- struct{}{}: + default: + } + } + return nil +} + +func TestManagerCollectOnceExportsOnlyWhenInventoryChanges(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{ + { + CollectedAt: time.Unix(100, 0).UTC(), + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + { + CollectedAt: time.Unix(200, 0).UTC(), + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + { + CollectedAt: time.Unix(300, 0).UTC(), + Hostname: "host-b", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }, + }, + } + sink := &fakeSink{} + mgr := NewManager(src, sink, InventoryConfig{}) + + snap1, err := mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, snap1.InventoryHash) + require.Len(t, sink.exported, 1) + + snap2, err := mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.Equal(t, snap1.InventoryHash, snap2.InventoryHash) + require.Len(t, sink.exported, 1) + + snap3, err := 
mgr.CollectOnce(context.Background()) + require.NoError(t, err) + require.NotEqual(t, snap1.InventoryHash, snap3.InventoryHash) + require.Len(t, sink.exported, 2) +} + +func TestManagerCollectOnceConcurrentExportSingleHash(t *testing.T) { + src := &fakeSource{ + snapshots: []*Snapshot{{ + CollectedAt: time.Unix(100, 0).UTC(), + Hostname: "host-a", + MachineID: "machine-id", + Resources: Resources{ + CPUInfo: CPUInfo{Type: "Xeon", LogicalCores: 64}, + }, + }}, + } + sink := &fakeSink{} + mgr := NewManager(src, sink, InventoryConfig{}) + + var ( + wg sync.WaitGroup + errs = make(chan error, 8) + ) + for range 8 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := mgr.CollectOnce(context.Background()) + errs <- err + }() + } + wg.Wait() + close(errs) + + for err := range errs { + require.NoError(t, err) + } + require.Len(t, sink.exported, 1) +} diff --git a/internal/inventory/mapper/backend.go b/internal/inventory/mapper/backend.go new file mode 100644 index 00000000..e76bdac5 --- /dev/null +++ b/internal/inventory/mapper/backend.go @@ -0,0 +1,114 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mapper contains inventory payload mappers. 
+package mapper + +import ( + "strconv" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +// ToNodeUpsertRequest maps an inventory snapshot to the backend node-upsert contract. +func ToNodeUpsertRequest(s *inventory.Snapshot) *backendclient.NodeUpsertRequest { + if s == nil { + return nil + } + gpus := make([]backendclient.GPUDevice, 0, len(s.Resources.GPUInfo.GPUs)) + for _, gpu := range s.Resources.GPUInfo.GPUs { + gpus = append(gpus, backendclient.GPUDevice{ + UUID: gpu.UUID, + BusID: gpu.BusID, + SN: gpu.SN, + MinorID: gpu.MinorID, + BoardID: gpu.BoardID, + VBIOSVersion: gpu.VBIOSVersion, + ChassisSN: gpu.ChassisSN, + GPUIndex: gpu.GPUIndex, + }) + } + + blockDevices := make([]backendclient.BlockDevice, 0, len(s.Resources.DiskInfo.BlockDevices)) + for _, disk := range s.Resources.DiskInfo.BlockDevices { + blockDevices = append(blockDevices, backendclient.BlockDevice{ + Name: disk.Name, + Type: disk.Type, + Size: disk.Size, + WWN: disk.WWN, + MountPoint: disk.MountPoint, + FSType: disk.FSType, + PartUUID: disk.PartUUID, + Parents: append([]string(nil), disk.Parents...), + }) + } + + interfaces := make([]backendclient.NICInterface, 0, len(s.Resources.NICInfo.PrivateIPInterfaces)) + for _, nic := range s.Resources.NICInfo.PrivateIPInterfaces { + interfaces = append(interfaces, backendclient.NICInterface{ + Interface: nic.Interface, + MAC: nic.MAC, + IP: nic.IP, + }) + } + + return &backendclient.NodeUpsertRequest{ + Hostname: s.Hostname, + AgentConfig: backendclient.AgentConfig{ + TotalComponents: s.AgentConfig.TotalComponents, + RetentionPeriodSeconds: s.AgentConfig.RetentionPeriodSeconds, + EnabledComponents: append([]string(nil), s.AgentConfig.EnabledComponents...), + DisabledComponents: append([]string(nil), s.AgentConfig.DisabledComponents...), + }, + MachineID: s.MachineID, + SystemUUID: s.SystemUUID, + BootID: s.BootID, + OperatingSystem: s.OperatingSystem, + 
OSImage: s.OSImage, + KernelVersion: s.KernelVersion, + AgentVersion: s.AgentVersion, + GPUDriverVersion: s.GPUDriverVersion, + CUDAVersion: s.CUDAVersion, + DCGMVersion: s.DCGMVersion, + ContainerRuntimeVersion: s.ContainerRuntimeVersion, + NetPrivateIP: s.NetPrivateIP, + Resources: backendclient.NodeResources{ + CPUInfo: backendclient.CPUInfo{ + Type: s.Resources.CPUInfo.Type, + Manufacturer: s.Resources.CPUInfo.Manufacturer, + Architecture: s.Resources.CPUInfo.Architecture, + LogicalCores: strconv.FormatInt(s.Resources.CPUInfo.LogicalCores, 10), + }, + MemoryInfo: backendclient.MemoryInfo{ + TotalBytes: strconv.FormatUint(s.Resources.MemoryInfo.TotalBytes, 10), + }, + GPUInfo: backendclient.GPUInfo{ + Product: s.Resources.GPUInfo.Product, + Manufacturer: s.Resources.GPUInfo.Manufacturer, + Architecture: s.Resources.GPUInfo.Architecture, + Memory: s.Resources.GPUInfo.Memory, + GPUs: gpus, + }, + DiskInfo: backendclient.DiskInfo{ + ContainerRootDisk: s.Resources.DiskInfo.ContainerRootDisk, + BlockDevices: blockDevices, + }, + NICInfo: backendclient.NICInfo{ + PrivateIPInterfaces: interfaces, + }, + }, + } +} diff --git a/internal/inventory/mapper/backend_test.go b/internal/inventory/mapper/backend_test.go new file mode 100644 index 00000000..c870e370 --- /dev/null +++ b/internal/inventory/mapper/backend_test.go @@ -0,0 +1,119 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package mapper + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +func TestToNodeUpsertRequestNil(t *testing.T) { + require.Nil(t, ToNodeUpsertRequest(nil)) +} + +func TestToNodeUpsertRequest(t *testing.T) { + req := ToNodeUpsertRequest(&inventory.Snapshot{ + Hostname: "host-a", + MachineID: "machine-id", + SystemUUID: "uuid-1", + BootID: "boot-1", + OperatingSystem: "linux", + OSImage: "Ubuntu", + KernelVersion: "6.5.0", + AgentVersion: "1.2.3", + GPUDriverVersion: "550.54.15", + CUDAVersion: "12.4", + DCGMVersion: "4.2.3", + ContainerRuntimeVersion: "containerd://1.7.13", + NetPrivateIP: "10.0.0.10", + AgentConfig: inventory.AgentConfig{ + TotalComponents: 30, + RetentionPeriodSeconds: 86400, + EnabledComponents: []string{"cpu", "gpu"}, + DisabledComponents: []string{"disk"}, + }, + Resources: inventory.Resources{ + CPUInfo: inventory.CPUInfo{ + Type: "Xeon", + Manufacturer: "Intel", + Architecture: "x86_64", + LogicalCores: 64, + }, + MemoryInfo: inventory.MemoryInfo{ + TotalBytes: 1024, + }, + GPUInfo: inventory.GPUInfo{ + Product: "H100", + Manufacturer: "NVIDIA", + Architecture: "Hopper", + Memory: "80GB", + GPUs: []inventory.GPUDevice{{ + UUID: "GPU-1", + BusID: "0000:01:00.0", + SN: "serial", + MinorID: "1", + BoardID: 7, + VBIOSVersion: "vbios", + ChassisSN: "chassis", + GPUIndex: "0", + }}, + }, + DiskInfo: inventory.DiskInfo{ + ContainerRootDisk: "/dev/nvme0n1", + BlockDevices: []inventory.BlockDevice{{ + Name: "nvme0n1", + Type: "disk", + Size: 2048, + WWN: "wwn", + MountPoint: "/", + FSType: "ext4", + PartUUID: "part-uuid", + Parents: []string{"parent0"}, + }}, + }, + NICInfo: inventory.NICInfo{ + PrivateIPInterfaces: []inventory.NICInterface{{ + Interface: "eth0", + MAC: "00:11:22:33:44:55", + IP: "10.0.0.10", + }}, + }, + }, + }) + + require.NotNil(t, 
req) + require.Equal(t, "host-a", req.Hostname) + require.Equal(t, "machine-id", req.MachineID) + require.Equal(t, int64(30), req.AgentConfig.TotalComponents) + require.Equal(t, int64(86400), req.AgentConfig.RetentionPeriodSeconds) + require.Equal(t, []string{"cpu", "gpu"}, req.AgentConfig.EnabledComponents) + require.Equal(t, []string{"disk"}, req.AgentConfig.DisabledComponents) + require.Equal(t, "64", req.Resources.CPUInfo.LogicalCores) + require.Equal(t, "1024", req.Resources.MemoryInfo.TotalBytes) + require.Equal(t, "H100", req.Resources.GPUInfo.Product) + require.Equal(t, "NVIDIA", req.Resources.GPUInfo.Manufacturer) + require.Len(t, req.Resources.GPUInfo.GPUs, 1) + require.Equal(t, 7, req.Resources.GPUInfo.GPUs[0].BoardID) + require.Equal(t, "/dev/nvme0n1", req.Resources.DiskInfo.ContainerRootDisk) + require.Len(t, req.Resources.DiskInfo.BlockDevices, 1) + require.Equal(t, "parent0", req.Resources.DiskInfo.BlockDevices[0].Parents[0]) + require.Len(t, req.Resources.NICInfo.PrivateIPInterfaces, 1) + require.Equal(t, "eth0", req.Resources.NICInfo.PrivateIPInterfaces[0].Interface) + require.Equal(t, "10.0.0.10", req.Resources.NICInfo.PrivateIPInterfaces[0].IP) +} diff --git a/internal/inventory/sink/backend.go b/internal/inventory/sink/backend.go new file mode 100644 index 00000000..a9b1dc1d --- /dev/null +++ b/internal/inventory/sink/backend.go @@ -0,0 +1,78 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package sink contains inventory sink implementations. +package sink + +import ( + "context" + "fmt" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/mapper" +) + +type backendSink struct { + state agentstate.State + clientFactory func(rawBaseURL string) (backendclient.Client, error) +} + +// NewBackendSink creates the backend inventory sink. +func NewBackendSink(state agentstate.State) inventory.Sink { + return &backendSink{ + state: state, + clientFactory: backendclient.New, + } +} + +func (s *backendSink) Export(ctx context.Context, snap *inventory.Snapshot) error { + if s.state == nil { + return fmt.Errorf("inventory backend export requires agent state") + } + if s.clientFactory == nil { + return fmt.Errorf("inventory backend export requires backend client factory") + } + if snap == nil { + return fmt.Errorf("inventory backend export requires inventory snapshot") + } + baseURL, ok, err := s.state.GetBackendBaseURL(ctx) + if err != nil { + return err + } + if !ok || baseURL == "" { + return inventory.ErrNotReady + } + jwt, ok, err := s.state.GetJWT(ctx) + if err != nil { + return err + } + if !ok || jwt == "" { + return inventory.ErrNotReady + } + nodeUUID, ok, err := s.state.GetNodeUUID(ctx) + if err != nil { + return err + } + if !ok || nodeUUID == "" { + return inventory.ErrNotReady + } + client, err := s.clientFactory(baseURL) + if err != nil { + return fmt.Errorf("create backend client: %w", err) + } + return client.UpsertNode(ctx, nodeUUID, mapper.ToNodeUpsertRequest(snap), jwt) +} diff --git a/internal/inventory/sink/backend_test.go b/internal/inventory/sink/backend_test.go new file mode 100644 index 00000000..f144b57f --- 
/dev/null +++ b/internal/inventory/sink/backend_test.go @@ -0,0 +1,141 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sink + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/backendclient" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" +) + +type fakeState struct { + baseURL string + jwt string + nodeUUID string + err error +} + +func (f fakeState) GetBackendBaseURL(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.baseURL, f.baseURL != "", nil +} +func (f fakeState) SetBackendBaseURL(context.Context, string) error { return nil } +func (f fakeState) GetJWT(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.jwt, f.jwt != "", nil +} +func (f fakeState) SetJWT(context.Context, string) error { return nil } +func (f fakeState) GetSAK(context.Context) (string, bool, error) { return "", false, nil } +func (f fakeState) SetSAK(context.Context, string) error { return nil } +func (f fakeState) GetNodeUUID(context.Context) (string, bool, error) { + if f.err != nil { + return "", false, f.err + } + return f.nodeUUID, f.nodeUUID != "", nil +} +func (f fakeState) SetNodeUUID(context.Context, string) 
error { return nil } + +type fakeClient struct { + nodeUUID string + req *backendclient.NodeUpsertRequest + jwt string +} + +func (f *fakeClient) Enroll(context.Context, string) (string, error) { return "", nil } +func (f *fakeClient) GetNonce(context.Context, string, string) (*backendclient.NonceResponse, error) { + return nil, nil +} +func (f *fakeClient) SubmitAttestation(context.Context, string, *backendclient.AttestationRequest, string) error { + return nil +} +func (f *fakeClient) RefreshToken(context.Context, string) (string, error) { return "", nil } +func (f *fakeClient) UpsertNode(_ context.Context, nodeUUID string, req *backendclient.NodeUpsertRequest, jwt string) error { + f.nodeUUID = nodeUUID + f.req = req + f.jwt = jwt + return nil +} + +func TestBackendSinkExportNotReady(t *testing.T) { + s := &backendSink{ + state: fakeState{}, + clientFactory: backendclient.New, + } + + err := s.Export(context.Background(), &inventory.Snapshot{}) + require.ErrorIs(t, err, inventory.ErrNotReady) +} + +func TestBackendSinkExportErrors(t *testing.T) { + err := (&backendSink{}).Export(context.Background(), &inventory.Snapshot{}) + require.ErrorContains(t, err, "agent state") + + err = (&backendSink{state: fakeState{baseURL: "https://example.com", jwt: "jwt"}}).Export(context.Background(), &inventory.Snapshot{}) + require.ErrorContains(t, err, "client factory") + + err = (&backendSink{ + state: fakeState{err: errors.New("state error")}, + clientFactory: backendclient.New, + }).Export(context.Background(), &inventory.Snapshot{}) + require.ErrorContains(t, err, "state error") + + err = (&backendSink{ + state: fakeState{baseURL: "https://example.com", jwt: "jwt"}, + clientFactory: backendclient.New, + }).Export(context.Background(), nil) + require.ErrorContains(t, err, "inventory snapshot") + + err = (&backendSink{ + state: fakeState{baseURL: "https://example.com", jwt: "jwt", nodeUUID: "node-1"}, + clientFactory: func(string) (backendclient.Client, error) { + return nil, 
errors.New("client factory error") + }, + }).Export(context.Background(), &inventory.Snapshot{}) + require.ErrorContains(t, err, "create backend client") +} + +func TestBackendSinkExportUsesState(t *testing.T) { + client := &fakeClient{} + s := &backendSink{ + state: fakeState{ + baseURL: "https://example.com", + jwt: "jwt-token", + nodeUUID: "node-1", + }, + clientFactory: func(string) (backendclient.Client, error) { + return client, nil + }, + } + + err := s.Export(context.Background(), &inventory.Snapshot{ + Hostname: "host-a", + MachineID: "machine-id", + }) + require.NoError(t, err) + require.Equal(t, "node-1", client.nodeUUID) + require.Equal(t, "jwt-token", client.jwt) + require.NotNil(t, client.req) + require.Equal(t, "host-a", client.req.Hostname) +} diff --git a/internal/inventory/source/source.go b/internal/inventory/source/source.go new file mode 100644 index 00000000..79e631b8 --- /dev/null +++ b/internal/inventory/source/source.go @@ -0,0 +1,153 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package source contains inventory collection adapters. 
+package source + +import ( + "context" + "fmt" + "time" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" +) + +// MachineInfoCollector is the local machine inventory collector dependency. +type MachineInfoCollector interface { + Collect(ctx context.Context) (*machineinfo.MachineInfo, error) +} + +type machineInfoSource struct { + collector MachineInfoCollector + agentConfig inventory.AgentConfig +} + +// NewMachineInfoSource wraps the machine inventory collector as an inventory source. +func NewMachineInfoSource(collector MachineInfoCollector) inventory.Source { + return &machineInfoSource{collector: collector} +} + +// NewMachineInfoSourceWithAgentConfig wraps the machine inventory collector and attaches useful +// agent configuration that should travel with inventory rather than OTLP telemetry. +func NewMachineInfoSourceWithAgentConfig(collector MachineInfoCollector, agentConfig *inventory.AgentConfig) inventory.Source { + var cfg inventory.AgentConfig + if agentConfig != nil { + cfg = *agentConfig + } + return &machineInfoSource{ + collector: collector, + agentConfig: cfg, + } +} + +func (s *machineInfoSource) Collect(ctx context.Context) (*inventory.Snapshot, error) { + if s.collector == nil { + return nil, fmt.Errorf("machine info collector is required") + } + info, err := s.collector.Collect(ctx) + if err != nil { + return nil, err + } + if info == nil { + return nil, fmt.Errorf("machine info collector returned nil machine info") + } + + snap := &inventory.Snapshot{ + CollectedAt: time.Now().UTC(), + Hostname: info.Hostname, + MachineID: info.MachineID, + SystemUUID: info.SystemUUID, + BootID: info.BootID, + OperatingSystem: info.OperatingSystem, + OSImage: info.OSImage, + KernelVersion: info.KernelVersion, + AgentVersion: info.AgentVersion, + GPUDriverVersion: info.GPUDriverVersion, + CUDAVersion: info.CUDAVersion, + DCGMVersion: info.DCGMVersion, + 
ContainerRuntimeVersion: info.ContainerRuntimeVersion, + AgentConfig: s.agentConfig, + } + if info.CPUInfo != nil { + snap.Resources.CPUInfo = inventory.CPUInfo{ + Type: info.CPUInfo.Type, + Manufacturer: info.CPUInfo.Manufacturer, + Architecture: info.CPUInfo.Architecture, + LogicalCores: info.CPUInfo.LogicalCores, + } + } + if info.MemoryInfo != nil { + snap.Resources.MemoryInfo = inventory.MemoryInfo{ + TotalBytes: info.MemoryInfo.TotalBytes, + } + } + if info.GPUInfo != nil { + snap.Resources.GPUInfo = inventory.GPUInfo{ + Product: info.GPUInfo.Product, + Manufacturer: info.GPUInfo.Manufacturer, + Architecture: info.GPUInfo.Architecture, + Memory: info.GPUInfo.Memory, + } + if len(info.GPUInfo.GPUs) > 0 { + snap.Resources.GPUInfo.GPUs = make([]inventory.GPUDevice, 0, len(info.GPUInfo.GPUs)) + for _, gpu := range info.GPUInfo.GPUs { + snap.Resources.GPUInfo.GPUs = append(snap.Resources.GPUInfo.GPUs, inventory.GPUDevice{ + UUID: gpu.UUID, + BusID: gpu.BusID, + SN: gpu.SN, + MinorID: gpu.MinorID, + BoardID: int(gpu.BoardID), + VBIOSVersion: gpu.VBIOSVersion, + ChassisSN: gpu.ChassisSN, + GPUIndex: gpu.GPUIndex, + }) + } + } + } + if info.DiskInfo != nil { + snap.Resources.DiskInfo = inventory.DiskInfo{ + ContainerRootDisk: info.DiskInfo.ContainerRootDisk, + } + if len(info.DiskInfo.BlockDevices) > 0 { + snap.Resources.DiskInfo.BlockDevices = make([]inventory.BlockDevice, 0, len(info.DiskInfo.BlockDevices)) + for _, disk := range info.DiskInfo.BlockDevices { + snap.Resources.DiskInfo.BlockDevices = append(snap.Resources.DiskInfo.BlockDevices, inventory.BlockDevice{ + Name: disk.Name, + Type: disk.Type, + Size: disk.Size, + WWN: disk.WWN, + MountPoint: disk.MountPoint, + FSType: disk.FSType, + PartUUID: disk.PartUUID, + Parents: append([]string(nil), disk.Parents...), + }) + } + } + } + if info.NICInfo != nil && len(info.NICInfo.PrivateIPInterfaces) > 0 { + snap.Resources.NICInfo.PrivateIPInterfaces = make([]inventory.NICInterface, 0, 
len(info.NICInfo.PrivateIPInterfaces)) + for _, nic := range info.NICInfo.PrivateIPInterfaces { + snap.Resources.NICInfo.PrivateIPInterfaces = append(snap.Resources.NICInfo.PrivateIPInterfaces, inventory.NICInterface{ + Interface: nic.Interface, + MAC: nic.MAC, + IP: nic.IP, + }) + } + snap.NetPrivateIP = info.NICInfo.PrivateIPInterfaces[0].IP + } + + return snap, nil +} diff --git a/internal/inventory/source/source_test.go b/internal/inventory/source/source_test.go new file mode 100644 index 00000000..3596b3dd --- /dev/null +++ b/internal/inventory/source/source_test.go @@ -0,0 +1,168 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package source + +import ( + "context" + "testing" + + apiv1 "github.com/NVIDIA/fleet-intelligence-sdk/api/v1" + "github.com/stretchr/testify/require" + + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" +) + +type fakeMachineInfoCollector struct { + info *machineinfo.MachineInfo + err error +} + +func (f fakeMachineInfoCollector) Collect(context.Context) (*machineinfo.MachineInfo, error) { + return f.info, f.err +} + +func TestMachineInfoSourceCollect(t *testing.T) { + src := NewMachineInfoSource(fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + AgentVersion: "1.2.3", + GPUDriverVersion: "550.54.15", + CUDAVersion: "12.4", + DCGMVersion: "4.2.3", + ContainerRuntimeVersion: "containerd://1.7.13", + KernelVersion: "6.5.0", + OSImage: "Ubuntu 22.04", + OperatingSystem: "linux", + SystemUUID: "system-uuid", + MachineID: "machine-id", + BootID: "boot-id", + Hostname: "host-a", + CPUInfo: &apiv1.MachineCPUInfo{ + Type: "Xeon", + Manufacturer: "Intel", + Architecture: "x86_64", + LogicalCores: 64, + }, + MemoryInfo: &apiv1.MachineMemoryInfo{ + TotalBytes: 1024, + }, + GPUInfo: &apiv1.MachineGPUInfo{ + Product: "H100", + Manufacturer: "NVIDIA", + Architecture: "Hopper", + Memory: "80GB", + GPUs: []apiv1.MachineGPUInstance{{ + UUID: "GPU-1", + GPUIndex: "0", + BusID: "0000:01:00.0", + SN: "serial", + MinorID: "1", + BoardID: 7, + VBIOSVersion: "vbios", + ChassisSN: "chassis", + }}, + }, + DiskInfo: &apiv1.MachineDiskInfo{ + ContainerRootDisk: "/dev/nvme0n1", + BlockDevices: []apiv1.MachineDiskDevice{{ + Name: "nvme0n1", + Type: "disk", + Size: 2048, + WWN: "wwn", + MountPoint: "/", + FSType: "ext4", + PartUUID: "part-uuid", + Parents: []string{"parent0"}, + }}, + }, + NICInfo: &apiv1.MachineNICInfo{ + PrivateIPInterfaces: []apiv1.MachineNetworkInterface{{ + Interface: "eth0", + MAC: "00:11:22:33:44:55", + IP: "10.0.0.10", + }}, + }, + }, + }) + + snap, err := 
src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) + require.Equal(t, "machine-id", snap.MachineID) + require.Equal(t, "host-a", snap.Hostname) + require.Equal(t, "10.0.0.10", snap.NetPrivateIP) + require.Equal(t, "Xeon", snap.Resources.CPUInfo.Type) + require.Equal(t, uint64(1024), snap.Resources.MemoryInfo.TotalBytes) + require.Equal(t, "H100", snap.Resources.GPUInfo.Product) + require.Len(t, snap.Resources.GPUInfo.GPUs, 1) + require.Equal(t, 7, snap.Resources.GPUInfo.GPUs[0].BoardID) + require.Equal(t, "/dev/nvme0n1", snap.Resources.DiskInfo.ContainerRootDisk) + require.Len(t, snap.Resources.DiskInfo.BlockDevices, 1) + require.Equal(t, "eth0", snap.Resources.NICInfo.PrivateIPInterfaces[0].Interface) +} + +func TestMachineInfoSourceCollectWithAgentConfig(t *testing.T) { + src := NewMachineInfoSourceWithAgentConfig( + fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + MachineID: "machine-id", + SystemUUID: "system-uuid", + Hostname: "host-a", + }, + }, + &inventory.AgentConfig{ + TotalComponents: 42, + RetentionPeriodSeconds: 86400, + EnabledComponents: []string{"cpu", "gpu"}, + DisabledComponents: []string{"disk"}, + }, + ) + + snap, err := src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) + require.Equal(t, "machine-id", snap.MachineID) + require.Equal(t, int64(42), snap.AgentConfig.TotalComponents) + require.Equal(t, int64(86400), snap.AgentConfig.RetentionPeriodSeconds) + require.Equal(t, []string{"cpu", "gpu"}, snap.AgentConfig.EnabledComponents) + require.Equal(t, []string{"disk"}, snap.AgentConfig.DisabledComponents) +} + +func TestMachineInfoSourceCollectIgnoresSystemUUIDForMachineID(t *testing.T) { + src := NewMachineInfoSource(fakeMachineInfoCollector{ + info: &machineinfo.MachineInfo{ + MachineID: "machine-id", + SystemUUID: "system-uuid", + Hostname: "host-a", + }, + }) + + snap, err := src.Collect(context.Background()) + require.NoError(t, err) + require.NotNil(t, snap) 
+ require.Equal(t, "machine-id", snap.MachineID) +} + +func TestMachineInfoSourceCollectErrors(t *testing.T) { + _, err := NewMachineInfoSource(nil).Collect(context.Background()) + require.ErrorContains(t, err, "collector is required") + + _, err = NewMachineInfoSource(fakeMachineInfoCollector{err: context.DeadlineExceeded}).Collect(context.Background()) + require.ErrorIs(t, err, context.DeadlineExceeded) + + _, err = NewMachineInfoSource(fakeMachineInfoCollector{}).Collect(context.Background()) + require.ErrorContains(t, err, "nil machine info") +} diff --git a/internal/inventory/types.go b/internal/inventory/types.go new file mode 100644 index 00000000..58075d82 --- /dev/null +++ b/internal/inventory/types.go @@ -0,0 +1,135 @@ +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package inventory owns inventory collection and sync state. +package inventory + +import ( + "context" + "errors" + "time" +) + +// ErrNotReady indicates inventory export cannot proceed because backend state is not ready yet. +var ErrNotReady = errors.New("inventory backend not ready") + +// Snapshot is the agent-owned inventory state model. 
type Snapshot struct {
	CollectedAt             time.Time
	InventoryHash           string
	Hostname                string
	MachineID               string
	SystemUUID              string
	BootID                  string
	OperatingSystem         string
	OSImage                 string
	KernelVersion           string
	AgentVersion            string
	GPUDriverVersion        string
	CUDAVersion             string
	DCGMVersion             string
	ContainerRuntimeVersion string
	NetPrivateIP            string
	AgentConfig             AgentConfig
	Resources               Resources
}

// Agent-side settings and the hardware resource model carried by a Snapshot.
type (
	// AgentConfig is the agent configuration reported alongside inventory:
	// a component count, a retention period in seconds, and the names of
	// enabled and disabled components.
	AgentConfig struct {
		TotalComponents        int64
		RetentionPeriodSeconds int64
		EnabledComponents      []string
		DisabledComponents     []string
	}

	// Resources groups the per-subsystem hardware sections of a Snapshot.
	Resources struct {
		CPUInfo    CPUInfo
		MemoryInfo MemoryInfo
		GPUInfo    GPUInfo
		DiskInfo   DiskInfo
		NICInfo    NICInfo
	}

	// CPUInfo is the CPU section of Resources.
	CPUInfo struct {
		Type         string
		Manufacturer string
		Architecture string
		LogicalCores int64
	}

	// MemoryInfo is the memory section of Resources.
	MemoryInfo struct {
		TotalBytes uint64
	}

	// GPUInfo is the GPU section of Resources: summary fields plus one
	// GPUDevice entry per reported device.
	GPUInfo struct {
		Product      string
		Manufacturer string
		Architecture string
		Memory       string
		GPUs         []GPUDevice
	}

	// GPUDevice identifies a single GPU listed in GPUInfo.
	GPUDevice struct {
		UUID         string
		BusID        string
		SN           string
		MinorID      string
		BoardID      int
		VBIOSVersion string
		ChassisSN    string
		GPUIndex     string
	}

	// DiskInfo is the disk section of Resources.
	DiskInfo struct {
		ContainerRootDisk string
		BlockDevices      []BlockDevice
	}

	// BlockDevice is a single entry of DiskInfo.BlockDevices.
	BlockDevice struct {
		Name       string
		Type       string
		Size       int64
		WWN        string
		MountPoint string
		FSType     string
		PartUUID   string
		Parents    []string
	}

	// NICInfo is the network section of Resources.
	NICInfo struct {
		PrivateIPInterfaces []NICInterface
	}

	// NICInterface is a single entry of NICInfo.PrivateIPInterfaces.
	NICInterface struct {
		Interface string
		MAC       string
		IP        string
	}
)

// The two seams of the inventory pipeline.
type (
	// Source collects inventory from local providers.
	Source interface {
		Collect(ctx context.Context) (*Snapshot, error)
	}

	// Sink exports inventory snapshots to an external destination.
	Sink interface {
		Export(ctx context.Context, snap *Snapshot) error
	}
)

// InventoryConfig controls periodic inventory scheduling.
+type InventoryConfig struct { + Interval time.Duration + RetryInterval time.Duration + JitterEnabled bool +} diff --git a/internal/machineinfo/machineinfo.go b/internal/machineinfo/machineinfo.go index 812f60b9..a6613cae 100644 --- a/internal/machineinfo/machineinfo.go +++ b/internal/machineinfo/machineinfo.go @@ -37,10 +37,10 @@ import ( var getDCGMVersion = dcgmversion.DetectHostengineVersion -// MachineInfo is a custom struct that replaces GPUdVersion with FleetintVersion +// MachineInfo is a custom struct that replaces GPUdVersion with AgentVersion. type MachineInfo struct { - // FleetintVersion represents the current version of Fleet Intelligence agent - FleetintVersion string `json:"fleetintVersion,omitempty"` + // AgentVersion represents the current version of the Fleet Intelligence agent. + AgentVersion string `json:"agentVersion,omitempty"` // GPUDriverVersion represents the current version of GPU driver installed GPUDriverVersion string `json:"gpuDriverVersion,omitempty"` // CUDAVersion represents the current version of cuda library. @@ -124,9 +124,9 @@ func GetMachineInfo(nvmlInstance nvidianvml.Instance, opts ...MachineInfoOption) dcgmVersion, _ := getDCGMVersion() - // Convert to our custom MachineInfo struct with Fleet Intelligence version + // Convert to our custom MachineInfo struct with the agent version. 
return &MachineInfo{ - FleetintVersion: version.Version, + AgentVersion: version.Version, GPUDriverVersion: gpudInfo.GPUDriverVersion, CUDAVersion: gpudInfo.CUDAVersion, DCGMVersion: dcgmVersion, @@ -156,8 +156,7 @@ func (i *MachineInfo) RenderTable(wr io.Writer) { table := tablewriter.NewWriter(wr) table.SetAlignment(tablewriter.ALIGN_CENTER) - // Show Fleetint Version instead of GPUd Version - table.Append([]string{"Fleetint Version", i.FleetintVersion}) + table.Append([]string{"Agent Version", i.AgentVersion}) table.Append([]string{"Container Runtime Version", i.ContainerRuntimeVersion}) table.Append([]string{"OS Image", i.OSImage}) table.Append([]string{"Kernel Version", i.KernelVersion}) diff --git a/internal/machineinfo/machineinfo_test.go b/internal/machineinfo/machineinfo_test.go index 4f10a5db..68b8c880 100644 --- a/internal/machineinfo/machineinfo_test.go +++ b/internal/machineinfo/machineinfo_test.go @@ -54,7 +54,7 @@ func TestGetMachineInfo(t *testing.T) { validate: func(t *testing.T, info *MachineInfo) { assert.NotNil(t, info) // The version should be set from the version package - assert.Equal(t, version.Version, info.FleetintVersion) + assert.Equal(t, version.Version, info.AgentVersion) assert.Equal(t, "4.2.3", info.DCGMVersion) // Other fields should be populated by the underlying GetMachineInfo assert.NotEmpty(t, info.Hostname) @@ -85,7 +85,7 @@ func TestMachineInfoStruct(t *testing.T) { now := metav1.Now() testInfo := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -128,7 +128,7 @@ func TestMachineInfoStruct(t *testing.T) { } // Verify all fields are set correctly - assert.Equal(t, "1.0.0-test", testInfo.FleetintVersion) + assert.Equal(t, "1.0.0-test", testInfo.AgentVersion) assert.Equal(t, "550.54.15", testInfo.GPUDriverVersion) assert.Equal(t, "12.4", testInfo.CUDAVersion) assert.Equal(t, "4.2.3", testInfo.DCGMVersion) @@ -199,7 +199,7 @@ 
func TestRenderTable_Empty(t *testing.T) { // TestRenderTable_BasicFields tests RenderTable with basic fields func TestRenderTable_BasicFields(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", ContainerRuntimeVersion: "containerd://1.7.13", OSImage: "Ubuntu 22.04.4 LTS", KernelVersion: "6.5.0-28-generic", @@ -222,7 +222,7 @@ func TestRenderTable_BasicFields(t *testing.T) { assert.Contains(t, output, "4.2.3") // Verify labels are present - assert.Contains(t, output, "Fleetint Version") + assert.Contains(t, output, "Agent Version") assert.Contains(t, output, "Container Runtime Version") assert.Contains(t, output, "OS Image") assert.Contains(t, output, "Kernel Version") @@ -233,7 +233,7 @@ func TestRenderTable_BasicFields(t *testing.T) { // TestRenderTable_WithCPUInfo tests RenderTable with CPU information func TestRenderTable_WithCPUInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", CPUInfo: &apiv1.MachineCPUInfo{ Type: "Intel(R) Xeon(R) CPU", Manufacturer: "GenuineIntel", @@ -264,7 +264,7 @@ func TestRenderTable_WithCPUInfo(t *testing.T) { // TestRenderTable_WithMemoryInfo tests RenderTable with memory information func TestRenderTable_WithMemoryInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", MemoryInfo: &apiv1.MachineMemoryInfo{ TotalBytes: 137438953472, // 128 GiB }, @@ -285,7 +285,7 @@ func TestRenderTable_WithMemoryInfo(t *testing.T) { // TestRenderTable_WithGPUInfo tests RenderTable with GPU information func TestRenderTable_WithGPUInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", GPUInfo: &apiv1.MachineGPUInfo{ Product: "NVIDIA A100-SXM4-80GB", @@ -319,7 +319,7 @@ func TestRenderTable_WithGPUInfo(t *testing.T) { // TestRenderTable_WithNICInfo tests RenderTable with network interface information func 
TestRenderTable_WithNICInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", NICInfo: &apiv1.MachineNICInfo{ PrivateIPInterfaces: []apiv1.MachineNetworkInterface{ { @@ -358,7 +358,7 @@ func TestRenderTable_WithNICInfo(t *testing.T) { // TestRenderTable_WithDiskInfo tests RenderTable with disk information func TestRenderTable_WithDiskInfo(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", DiskInfo: &apiv1.MachineDiskInfo{ ContainerRootDisk: "/dev/sda1", }, @@ -380,7 +380,7 @@ func TestRenderTable_Complete(t *testing.T) { now := metav1.NewTime(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)) info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", ContainerRuntimeVersion: "containerd://1.7.13", @@ -466,7 +466,7 @@ func TestMachineInfo_JSONMarshaling(t *testing.T) { // The actual marshaling is tested implicitly through the struct definition info := &MachineInfo{ - FleetintVersion: "1.0.0", + AgentVersion: "1.0.0", GPUDriverVersion: "550.54.15", CUDAVersion: "12.4", DCGMVersion: "4.2.3", @@ -482,7 +482,7 @@ func TestMachineInfo_JSONMarshaling(t *testing.T) { // Verify all fields have proper json tags assert.NotNil(t, info) - assert.NotEmpty(t, info.FleetintVersion) + assert.NotEmpty(t, info.AgentVersion) assert.NotEmpty(t, info.GPUDriverVersion) assert.NotEmpty(t, info.CUDAVersion) assert.NotEmpty(t, info.DCGMVersion) @@ -506,7 +506,7 @@ func TestGetMachineInfo_DCGMVersionBestEffort(t *testing.T) { // TestRenderTable_WithNilSubStructs tests that RenderTable handles nil sub-structs gracefully func TestRenderTable_WithNilSubStructs(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", // All sub-structs are intentionally nil CPUInfo: nil, MemoryInfo: nil, @@ -526,13 +526,13 @@ func TestRenderTable_WithNilSubStructs(t *testing.T) { // Should still 
show basic info assert.Contains(t, output, "1.0.0-test") - assert.Contains(t, output, "Fleetint Version") + assert.Contains(t, output, "Agent Version") } // TestRenderTable_EmptyNICList tests RenderTable with empty NIC list func TestRenderTable_EmptyNICList(t *testing.T) { info := &MachineInfo{ - FleetintVersion: "1.0.0-test", + AgentVersion: "1.0.0-test", NICInfo: &apiv1.MachineNICInfo{ PrivateIPInterfaces: []apiv1.MachineNetworkInterface{}, }, diff --git a/internal/server/server.go b/internal/server/server.go index 8649c478..c1177786 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -53,8 +53,14 @@ import ( nvidianvml "github.com/NVIDIA/fleet-intelligence-sdk/pkg/nvidia-query/nvml" "github.com/NVIDIA/fleet-intelligence-sdk/pkg/sqlite" + "github.com/NVIDIA/fleet-intelligence-agent/internal/agentstate" + "github.com/NVIDIA/fleet-intelligence-agent/internal/attestation" "github.com/NVIDIA/fleet-intelligence-agent/internal/config" "github.com/NVIDIA/fleet-intelligence-agent/internal/exporter" + "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory" + inventorysink "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/sink" + inventorysource "github.com/NVIDIA/fleet-intelligence-agent/internal/inventory/source" + "github.com/NVIDIA/fleet-intelligence-agent/internal/machineinfo" "github.com/NVIDIA/fleet-intelligence-agent/internal/registry" ) @@ -87,6 +93,12 @@ type Server struct { machineID string } +type inventoryMachineInfoCollectorFunc func(context.Context) (*machineinfo.MachineInfo, error) + +func (f inventoryMachineInfoCollectorFunc) Collect(ctx context.Context) (*machineinfo.MachineInfo, error) { + return f(ctx) +} + // initializeDatabases opens and initializes database connections func initializeDatabases(ctx context.Context, cfg *config.Config) (*sql.DB, *sql.DB, error) { stateFile := ":memory:" @@ -164,6 +176,33 @@ func getHealthCheckInterval(config *config.Config) time.Duration { return healthCheckInterval } +func 
getInventorySyncInterval(config *config.Config) time.Duration { + if config == nil { + return 0 + } + if config.Inventory != nil { + if !config.Inventory.Enabled { + return 0 + } + return config.Inventory.Interval.Duration + } + return 0 +} + +func getAttestationInterval(config *config.Config) time.Duration { + if config == nil || config.Attestation == nil || !config.Attestation.Enabled { + return 0 + } + return config.Attestation.Interval.Duration +} + +func getAttestationTimeout(config *config.Config) time.Duration { + if config == nil || config.HealthExporter == nil { + return 0 + } + return config.HealthExporter.Timeout.Duration +} + // shouldEnableComponent determines if a component should be enabled based on configuration func shouldEnableComponent(name string, enabledByDefault bool, config *config.Config) bool { shouldEnable := enabledByDefault @@ -337,6 +376,9 @@ func New(ctx context.Context, auditLogger log.AuditLogger, config *config.Config } } + s.startInventoryLoop(ctx, config, nvmlInstance, dcgmGPUIndexes) + s.startAttestationLoop(ctx, config) + // Create and start health exporter with all dependencies if enabled if config.HealthExporter != nil { var err error @@ -368,6 +410,76 @@ func New(ctx context.Context, auditLogger log.AuditLogger, config *config.Config return s, nil } +func (s *Server) startInventoryLoop( + ctx context.Context, + cfg *config.Config, + nvmlInstance nvidianvml.Instance, + dcgmGPUIndexes map[string]string, +) { + interval := getInventorySyncInterval(cfg) + if interval <= 0 { + return + } + + allComponents := registry.AllComponentNames() + retentionPeriodSeconds, enabledComponents, disabledComponents := cfg.InventoryAgentConfig(allComponents) + + source := inventorysource.NewMachineInfoSourceWithAgentConfig( + inventoryMachineInfoCollectorFunc(func(context.Context) (*machineinfo.MachineInfo, error) { + return machineinfo.GetMachineInfo(nvmlInstance, machineinfo.WithDCGMGPUIndexes(dcgmGPUIndexes)) + }), + &inventory.AgentConfig{ + 
TotalComponents: int64(len(allComponents)), + RetentionPeriodSeconds: retentionPeriodSeconds, + EnabledComponents: enabledComponents, + DisabledComponents: disabledComponents, + }, + ) + sink := inventorysink.NewBackendSink(agentstate.NewSQLite()) + manager := inventory.NewManager(source, sink, inventory.InventoryConfig{ + Interval: interval, + RetryInterval: inventoryRetryInterval, + JitterEnabled: true, + }) + + go func() { + if err := manager.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Logger.Errorw("inventory loop manager exited", "error", err) + } + }() +} + +const attestationRetryInterval = 5 * time.Minute +const inventoryRetryInterval = 1 * time.Minute + +func (s *Server) startAttestationLoop(ctx context.Context, cfg *config.Config) { + interval := getAttestationInterval(cfg) + if interval <= 0 { + return + } + + state := agentstate.NewSQLite() + manager := attestation.NewManager( + attestation.NewStateNodeUUIDProvider(state), + attestation.NewStateJWTProvider(state), + attestation.NewStateNonceProvider(state), + attestation.NewCLIEvidenceCollector(getAttestationTimeout(cfg)), + attestation.NewStateBackendSubmitter(state), + attestation.AttestationConfig{ + InitialInterval: cfg.Attestation.InitialInterval.Duration, + Interval: interval, + RetryInterval: attestationRetryInterval, + JitterEnabled: true, + }, + ) + + go func() { + if err := manager.Run(ctx); err != nil && !errors.Is(err, context.Canceled) { + log.Logger.Errorw("attestation loop exited", "error", err) + } + }() +} + // GetHealthExporter returns the health exporter instance (for offline mode access) func (s *Server) GetHealthExporter() exporter.Exporter { return s.healthExporter diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 5964fdd9..6eb1f6f3 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -79,6 +79,94 @@ func TestGetHealthCheckInterval(t *testing.T) { } } +func TestGetInventorySyncInterval(t 
*testing.T) { + tests := []struct { + name string + config *config.Config + expected time.Duration + }{ + { + name: "nil config", + config: nil, + expected: 0, + }, + { + name: "uses inventory sync config", + config: &config.Config{ + Inventory: &config.InventoryConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 2 * time.Minute}, + }, + }, + expected: 2 * time.Minute, + }, + { + name: "disabled inventory sync", + config: &config.Config{ + Inventory: &config.InventoryConfig{ + Enabled: false, + Interval: metav1.Duration{Duration: 2 * time.Minute}, + }, + }, + expected: 0, + }, + { + name: "no inventory or exporter config", + config: &config.Config{}, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.expected, getInventorySyncInterval(tt.config)) + }) + } +} + +func TestGetAttestationSettings(t *testing.T) { + tests := []struct { + name string + config *config.Config + wantInterval time.Duration + wantTimeout time.Duration + }{ + { + name: "nil config", + config: nil, + wantInterval: 0, + wantTimeout: 0, + }, + { + name: "no exporter", + config: &config.Config{}, + wantInterval: 0, + wantTimeout: 0, + }, + { + name: "uses health exporter attestation settings", + config: &config.Config{ + HealthExporter: &config.HealthExporterConfig{ + Timeout: metav1.Duration{Duration: 45 * time.Second}, + }, + Attestation: &config.AttestationConfig{ + Enabled: true, + Interval: metav1.Duration{Duration: 6 * time.Hour}, + }, + }, + wantInterval: 6 * time.Hour, + wantTimeout: 45 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.wantInterval, getAttestationInterval(tt.config)) + assert.Equal(t, tt.wantTimeout, getAttestationTimeout(tt.config)) + }) + } +} + // TestShouldEnableComponent tests the shouldEnableComponent function. func TestShouldEnableComponent(t *testing.T) { tests := []struct {